[Mlir-commits] [mlir] 2abad34 - [mlir][rocdl] Adding vector to ROCDL dialect lowering

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 11 07:35:26 PDT 2020


Author: jerryyin
Date: 2020-06-11T14:28:13Z
New Revision: 2abad3433f9f48cb0a103726a9af1ad79603d23d

URL: https://github.com/llvm/llvm-project/commit/2abad3433f9f48cb0a103726a9af1ad79603d23d
DIFF: https://github.com/llvm/llvm-project/commit/2abad3433f9f48cb0a103726a9af1ad79603d23d.diff

LOG: [mlir][rocdl] Adding vector to ROCDL dialect lowering

* Created the vector to ROCDL lowering pass
  * The lowering pass lowers vector transferOps to rocdl mubufOps
* Added unit test and functional test

Added: 
    mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
    mlir/lib/Conversion/VectorToROCDL/CMakeLists.txt
    mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
    mlir/test/Conversion/VectorToROCDL/vector-to-rocdl.mlir
    mlir/test/mlir-rocm-runner/vector-transferops.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 9b9caa46feda..d74e419dbb1a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -301,4 +301,14 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
   let constructor = "mlir::createConvertVectorToLLVMPass()";
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToROCDL
+//===----------------------------------------------------------------------===//
+
+// Standalone module pass that lowers 1-D vector transfer ops to ROCDL MUBUF
+// buffer ops. The same patterns are also registered by the GPU to ROCDL
+// lowering pass.
+def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> {
+  let summary = "Lower the operations from the vector dialect into the ROCDL "
+                "dialect";
+  let constructor = "mlir::createConvertVectorToROCDLPass()";
+}
+
 #endif // MLIR_CONVERSION_PASSES

diff  --git a/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h b/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
new file mode 100644
index 000000000000..660de02ee36f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
@@ -0,0 +1,28 @@
+//===- VectorToROCDL.h - Convert Vector to ROCDL dialect ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_VECTORTOROCDL_VECTORTOROCDL_H_
+#define MLIR_CONVERSION_VECTORTOROCDL_VECTORTOROCDL_H_
+
+#include <memory>
+
+namespace mlir {
+class LLVMTypeConverter;
+class OwningRewritePatternList;
+class ModuleOp;
+template <typename OpT>
+class OperationPass;
+
+/// Collect a set of patterns that convert vector.transfer_read and
+/// vector.transfer_write operations from the Vector dialect to ROCDL MUBUF
+/// buffer operations.
+/// \param converter LLVM type converter used by the patterns.
+/// \param patterns  list the new patterns are appended to.
+void populateVectorToROCDLConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Create a module pass that converts vector operations to the ROCDL dialect
+/// (together with the standard-to-LLVM patterns).
+std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToROCDLPass();
+
+} // namespace mlir
+#endif // MLIR_CONVERSION_VECTORTOROCDL_VECTORTOROCDL_H_

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index f1b78f929664..ca5c8e0dac46 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -30,6 +30,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/GPU/Passes.h"

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e60a1dfc0e32..698dbd269b8e 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -14,5 +14,6 @@ add_subdirectory(ShapeToStandard)
 add_subdirectory(SPIRVToLLVM)
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
+add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToSCF)

diff  --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
index 38b6e1a5ea4f..3e7294d54ac2 100644
--- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_conversion_library(MLIRGPUtoROCDLTransforms
   MLIRROCDLIR
   MLIRPass
   MLIRStandardToLLVM
+  MLIRVectorToROCDL
   )

diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index f81c382152d7..5707075767ed 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -55,6 +56,7 @@ class LowerGpuOpsToROCDLOpsPass
     patterns.clear();
 
     populateVectorToLLVMConversionPatterns(converter, patterns);
+    populateVectorToROCDLConversionPatterns(converter, patterns);
     populateStdToLLVMConversionPatterns(converter, patterns);
     populateGpuToROCDLConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());

diff  --git a/mlir/lib/Conversion/VectorToROCDL/CMakeLists.txt b/mlir/lib/Conversion/VectorToROCDL/CMakeLists.txt
new file mode 100644
index 000000000000..5cc1f5de2bf0
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToROCDL/CMakeLists.txt
@@ -0,0 +1,19 @@
+# Vector -> ROCDL conversion library: lowers 1-D vector transfer ops to ROCDL
+# MUBUF buffer ops. Also linked into MLIRGPUtoROCDLTransforms.
+add_mlir_conversion_library(MLIRVectorToROCDL
+  VectorToROCDL.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToROCDL
+
+  DEPENDS
+  MLIRConversionPassIncGen
+  intrinsics_gen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRROCDLIR
+  MLIRStandardToLLVM
+  MLIRVector
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
new file mode 100644
index 000000000000..a1d483be78f3
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -0,0 +1,195 @@
+//===- VectorToROCDL.cpp - Vector to ROCDL lowering passes ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to generate ROCDLIR operations for higher-level
+// Vector operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// Wraps the converted `operands` of a transfer_read in its operand adaptor.
+/// The op argument is only used to select this overload.
+static TransferReadOpOperandAdaptor
+getTransferOpAdapter(TransferReadOp /*xferOp*/, ArrayRef<Value> operands) {
+  TransferReadOpOperandAdaptor readAdaptor(operands);
+  return readAdaptor;
+}
+
+/// Wraps the converted `operands` of a transfer_write in its operand adaptor.
+/// The op argument is only used to select this overload.
+static TransferWriteOpOperandAdaptor
+getTransferOpAdapter(TransferWriteOp /*xferOp*/, ArrayRef<Value> operands) {
+  TransferWriteOpOperandAdaptor writeAdaptor(operands);
+  return writeAdaptor;
+}
+
+/// Replaces a vector.transfer_read with a ROCDL MUBUF buffer load producing a
+/// value of type `vecTy`, addressed by the buffer descriptor `dwordConfig`.
+/// `operands`, `typeConverter` and `loc` are unused. They keep this signature
+/// identical to the transfer_write overload so the templated conversion
+/// pattern can call either one. Value-like handles are taken by value: they
+/// are cheap to copy and are never modified here.
+static LogicalResult replaceTransferOpWithMubuf(
+    ConversionPatternRewriter &rewriter, ArrayRef<Value> /*operands*/,
+    LLVMTypeConverter & /*typeConverter*/, Location /*loc*/,
+    TransferReadOp xferOp, LLVM::LLVMType vecTy, Value dwordConfig,
+    Value vindex, Value offsetSizeInBytes, Value glc, Value slc) {
+  rewriter.replaceOpWithNewOp<ROCDL::MubufLoadOp>(
+      xferOp, vecTy, dwordConfig, vindex, offsetSizeInBytes, glc, slc);
+  return success();
+}
+
+/// Replaces a vector.transfer_write with a ROCDL MUBUF buffer store of the
+/// vector operand (taken from `operands`), addressed by the buffer descriptor
+/// `dwordConfig`. `typeConverter`, `loc` and `vecTy` are unused. They keep
+/// this signature identical to the transfer_read overload. Value-like handles
+/// are taken by value: they are cheap to copy and are never modified here.
+static LogicalResult replaceTransferOpWithMubuf(
+    ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
+    LLVMTypeConverter & /*typeConverter*/, Location /*loc*/,
+    TransferWriteOp xferOp, LLVM::LLVMType /*vecTy*/, Value dwordConfig,
+    Value vindex, Value offsetSizeInBytes, Value glc, Value slc) {
+  auto adaptor = TransferWriteOpOperandAdaptor(operands);
+  rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
+                                                   dwordConfig, vindex,
+                                                   offsetSizeInBytes, glc, slc);
+  return success();
+}
+
+namespace {
+/// Conversion pattern that lowers a 1-D vector.transfer_read or
+/// vector.transfer_write (chosen by `ConcreteOp`) to a ROCDL MUBUF buffer
+/// load or store.
+///
+/// The pattern only applies when all of the following hold:
+///   - the vector is rank 1 and the transfer has at least one index;
+///   - the permutation map is a minor identity;
+///   - dimension 0 is masked;
+///   - the vector width is 2 or 4;
+///   - the memref is in memory space 0 or 1.
+/// In every other case it returns failure(), so other patterns (e.g. the
+/// vector to LLVM conversion patterns) can handle the op.
+/// NOTE(review): an earlier version of this comment said only f32 element
+/// types are converted, but the element type is never checked below. Confirm
+/// whether other element types are meant to reach the MUBUF lowering.
+template <typename ConcreteOp>
+class VectorTransferConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorTransferConversion(MLIRContext *context,
+                                    LLVMTypeConverter &typeConv)
+      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
+                             typeConv) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto xferOp = cast<ConcreteOp>(op);
+    // Named access (memref, indices, ...) to the `operands` of this op.
+    auto adaptor = getTransferOpAdapter(xferOp, operands);
+
+    // Reject multi-dimensional vectors and transfers without indices.
+    if (xferOp.getVectorType().getRank() > 1 ||
+        llvm::size(xferOp.indices()) == 0)
+      return failure();
+
+    // Only minor-identity permutation maps are supported.
+    if (!AffineMap::isMinorIdentity(xferOp.permutation_map()))
+      return failure();
+
+    // Unmasked transfers are left to the vector->llvm conversion patterns;
+    // only transfers whose dimension 0 is masked are lowered here.
+    // NOTE(review): no explicit mask is materialized below. Confirm how
+    // out-of-bounds lanes are handled by the buffer instruction.
+    if (!xferOp.isMaskedDim(0))
+      return failure();
+
+    // Converts a builtin type with this pattern's LLVM type converter.
+    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
+    LLVM::LLVMType vecTy =
+        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
+    unsigned vecWidth = vecTy.getVectorNumElements();
+    Location loc = op->getLoc();
+
+    // The backend's vector-result scalarization has trouble scalarizing a
+    // <1 x ty> result, so width 1 is excluded. Only widths 2 and 4 are
+    // lowered here.
+    if (vecWidth != 2 && vecWidth != 4)
+      return failure();
+
+    // Inspect the memref type to check its memory space.
+    MemRefType memRefType = xferOp.getMemRefType();
+    // MUBUF instructions operate only on address space 0 (unified) or
+    // 1 (global). Address space 3 (LDS) falls back to the vector->llvm
+    // patterns. Any other space is also rejected; the original note singles
+    // out 5 as wrong.
+    if ((memRefType.getMemorySpace() != 0) &&
+        (memRefType.getMemorySpace() != 1))
+      return failure();
+
+    // dataPtr already points at the element addressed by `indices`, so the
+    // MUBUF byte offset does not need to be recomputed. It is passed as 0
+    // below.
+    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
+                               adaptor.indices(), rewriter, getModule());
+
+    // 1. Create and fill a <4 x i32> dwordConfig with:
+    //    1st two elements holding the address of dataPtr.
+    //    3rd element: -1.
+    //    4th element: 0x27000 (159744).
+    // NOTE(review): the meaning of -1 and 0x27000 is not documented here.
+    // Consider giving these descriptor fields named constants.
+    SmallVector<int32_t, 4> constConfigAttr{0, 0, -1, 0x27000};
+    Type i32Ty = rewriter.getIntegerType(32);
+    VectorType i32Vecx4 = VectorType::get(4, i32Ty);
+    Value constConfig = rewriter.create<LLVM::ConstantOp>(
+        loc, toLLVMTy(i32Vecx4),
+        DenseElementsAttr::get(i32Vecx4, ArrayRef<int32_t>(constConfigAttr)));
+
+    // Treat the first two elements of the <4 x i32> as one i64 and store the
+    // dataPtr there. Despite its name, `i64x2Ty` is a Value: constConfig
+    // bitcast to <2 x i64>.
+    Type i64Ty = rewriter.getIntegerType(64);
+    Value i64x2Ty = rewriter.create<LLVM::BitcastOp>(
+        loc,
+        LLVM::LLVMType::getVectorTy(
+            toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
+        constConfig);
+    Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
+        loc, toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), dataPtr);
+    // Overwrite i64 element 0 (the first two i32 elements) with the pointer,
+    // then bitcast back to <4 x i32>.
+    Value zero = createIndexConstant(rewriter, loc, 0);
+    Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
+        loc,
+        LLVM::LLVMType::getVectorTy(
+            toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
+        i64x2Ty, dataPtrAsI64, zero);
+    dwordConfig =
+        rewriter.create<LLVM::BitcastOp>(loc, toLLVMTy(i32Vecx4), dwordConfig);
+
+    // 2. Rewrite op as a buffer read or write. vindex and the byte offset
+    //    are both i32 0; glc and slc are both i1 false.
+    Value int1False = rewriter.create<LLVM::ConstantOp>(
+        loc, toLLVMTy(rewriter.getIntegerType(1)),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
+    Value int32Zero = rewriter.create<LLVM::ConstantOp>(
+        loc, toLLVMTy(i32Ty),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
+    return replaceTransferOpWithMubuf(rewriter, operands, typeConverter, loc,
+                                      xferOp, vecTy, dwordConfig, int32Zero,
+                                      int32Zero, int1False, int1False);
+  }
+};
+} // end anonymous namespace
+
+/// Appends the transfer_read and transfer_write MUBUF lowering patterns (in
+/// that order) to `patterns`. Both patterns share `converter` and its context.
+void mlir::populateVectorToROCDLConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  MLIRContext *context = converter.getDialect()->getContext();
+  patterns.insert<VectorTransferConversion<TransferReadOp>>(context, converter);
+  patterns.insert<VectorTransferConversion<TransferWriteOp>>(context,
+                                                             converter);
+}
+
+namespace {
+/// Standalone pass that applies the vector->ROCDL patterns together with the
+/// standard->LLVM patterns. ROCDL ops are added to the LLVM conversion
+/// target's legal set. A partial conversion is used, so ops no pattern
+/// handles are left in place.
+struct LowerVectorToROCDLPass
+    : public ConvertVectorToROCDLBase<LowerVectorToROCDLPass> {
+  void runOnOperation() override {
+    MLIRContext &ctx = getContext();
+    LLVMTypeConverter typeConverter(&ctx);
+
+    OwningRewritePatternList conversionPatterns;
+    populateVectorToROCDLConversionPatterns(typeConverter, conversionPatterns);
+    populateStdToLLVMConversionPatterns(typeConverter, conversionPatterns);
+
+    LLVMConversionTarget conversionTarget(ctx);
+    conversionTarget.addLegalDialect<ROCDL::ROCDLDialect>();
+
+    LogicalResult converted = applyPartialConversion(
+        getOperation(), conversionTarget, conversionPatterns, &typeConverter);
+    if (failed(converted))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Factory for the standalone vector->ROCDL module pass.
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertVectorToROCDLPass() {
+  std::unique_ptr<OperationPass<ModuleOp>> pass =
+      std::make_unique<LowerVectorToROCDLPass>();
+  return pass;
+}

diff  --git a/mlir/test/Conversion/VectorToROCDL/vector-to-rocdl.mlir b/mlir/test/Conversion/VectorToROCDL/vector-to-rocdl.mlir
new file mode 100644
index 000000000000..1113197aa589
--- /dev/null
+++ b/mlir/test/Conversion/VectorToROCDL/vector-to-rocdl.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s -convert-vector-to-rocdl | FileCheck %s
+
+// vector.transfer_read lowering: 1-D reads of width 2 and 4 from a memref in
+// memory space 0 are expected to become rocdl.buffer.load.
+gpu.module @test_read{
+func @transfer_readx2(%A : memref<?xf32>, %base: index) -> vector<2xf32> {
+  %f0 = constant 0.0: f32
+  %f = vector.transfer_read %A[%base], %f0
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    memref<?xf32>, vector<2xf32>
+  return %f: vector<2xf32>
+}
+// CHECK-LABEL: @transfer_readx2
+// CHECK: rocdl.buffer.load {{.*}} !llvm<"<2 x float>">
+
+func @transfer_readx4(%A : memref<?xf32>, %base: index) -> vector<4xf32> {
+  %f0 = constant 0.0: f32
+  %f = vector.transfer_read %A[%base], %f0
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    memref<?xf32>, vector<4xf32>
+  return %f: vector<4xf32>
+}
+// CHECK-LABEL: @transfer_readx4
+// CHECK: rocdl.buffer.load {{.*}} !llvm<"<4 x float>">
+
+// Buffer descriptor construction: the element GEP, the constant
+// [0, 0, -1, 159744] (159744 == 0x27000), then the ptrtoint of the GEP
+// inserted into the descriptor.
+func @transfer_read_dwordConfig(%A : memref<?xf32>, %base: index) -> vector<4xf32> {
+  %f0 = constant 0.0: f32
+  %f = vector.transfer_read %A[%base], %f0
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    memref<?xf32>, vector<4xf32>
+  return %f: vector<4xf32>
+}
+// CHECK-LABEL: @transfer_read_dwordConfig
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}}
+// CHECK: [0, 0, -1, 159744]
+// CHECK: %[[i64:.*]] = llvm.ptrtoint %[[gep]]
+// CHECK: llvm.insertelement %[[i64]]
+}
+
+// vector.transfer_write lowering: 1-D writes of width 2 and 4 to a memref in
+// memory space 0 are expected to become rocdl.buffer.store.
+gpu.module @test_write{
+func @transfer_writex2(%A : memref<?xf32>, %B : vector<2xf32>, %base: index) {
+  vector.transfer_write %B, %A[%base]
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    vector<2xf32>, memref<?xf32>
+  return
+}
+// CHECK-LABEL: @transfer_writex2
+// CHECK: rocdl.buffer.store {{.*}} !llvm<"<2 x float>">
+
+func @transfer_writex4(%A : memref<?xf32>, %B : vector<4xf32>, %base: index) {
+  vector.transfer_write %B, %A[%base]
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    vector<4xf32>, memref<?xf32>
+  return
+}
+// CHECK-LABEL: @transfer_writex4
+// CHECK: rocdl.buffer.store {{.*}} !llvm<"<4 x float>">
+
+// Buffer descriptor construction for the store path. It uses the same
+// sequence as the read path: GEP, constant [0, 0, -1, 159744], ptrtoint,
+// insertelement.
+func @transfer_write_dwordConfig(%A : memref<?xf32>, %B : vector<2xf32>, %base: index) {
+  vector.transfer_write %B, %A[%base]
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    vector<2xf32>, memref<?xf32>
+  return
+}
+// CHECK-LABEL: @transfer_write_dwordConfig
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}}
+// CHECK: [0, 0, -1, 159744]
+// CHECK: %[[i64:.*]] = llvm.ptrtoint %[[gep]]
+// CHECK: llvm.insertelement %[[i64]]
+}

diff  --git a/mlir/test/mlir-rocm-runner/vector-transferops.mlir b/mlir/test/mlir-rocm-runner/vector-transferops.mlir
new file mode 100644
index 000000000000..b028f91f8394
--- /dev/null
+++ b/mlir/test/mlir-rocm-runner/vector-transferops.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-rocm-runner %s --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// Kernel launched with one block of one thread. It reads elements 0-1 of
+// %arg0 as a vector<2xf32>, doubles them, and writes the result to elements
+// 1-2 of %arg1.
+func @vectransferx2(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>) {
+  %cst = constant 1 : index
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst)
+             threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, %block_z = %cst) {
+    %f0 = constant 0.0: f32
+    %base = constant 0 : index
+    %f = vector.transfer_read %arg0[%base], %f0
+        {permutation_map = affine_map<(d0) -> (d0)>} :
+      memref<?xf32>, vector<2xf32>
+
+    %c = addf %f, %f : vector<2xf32>
+
+    %base1 = constant 1 : index
+    vector.transfer_write %c, %arg1[%base1]
+        {permutation_map = affine_map<(d0) -> (d0)>} :
+      vector<2xf32>, memref<?xf32>
+
+    gpu.terminator
+  }
+  return
+}
+
+// Kernel launched with one block of one thread. It reads elements 0-3 of
+// %arg0 as a vector<4xf32>, doubles them, and writes the result to elements
+// 0-3 of %arg1.
+func @vectransferx4(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>) {
+  %cst = constant 1 : index
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst)
+             threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, %block_z = %cst) {
+    %f0 = constant 0.0: f32
+    %base = constant 0 : index
+    %f = vector.transfer_read %arg0[%base], %f0
+        {permutation_map = affine_map<(d0) -> (d0)>} :
+      memref<?xf32>, vector<4xf32>
+
+    %c = addf %f, %f : vector<4xf32>
+
+    vector.transfer_write %c, %arg1[%base]
+        {permutation_map = affine_map<(d0) -> (d0)>} :
+      vector<4xf32>, memref<?xf32>
+
+    gpu.terminator
+  }
+  return
+}
+
+// Host driver. It allocates two 4-element f32 buffers and registers them with
+// the GPU runtime. It then runs both kernels on the memrefs returned by
+// mgpuMemGetDeviceMemRef1dFloat and prints the output buffer after each run.
+// NOTE(review): neither buffer is written in this file. The expected 1.23
+// values presumably come from mgpuMemHostRegisterFloat initializing the
+// registered memory; confirm against the ROCm runtime wrappers.
+func @main() {
+  %arg0 = alloc() : memref<4xf32>
+  %arg1 = alloc() : memref<4xf32>
+
+  %22 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
+  %23 = memref_cast %arg1 : memref<4xf32> to memref<?xf32>
+
+  %cast0 = memref_cast %22 : memref<?xf32> to memref<*xf32>
+  %cast1 = memref_cast %23 : memref<?xf32> to memref<*xf32>
+
+  call @mgpuMemHostRegisterFloat(%cast0) : (memref<*xf32>) -> ()
+  call @mgpuMemHostRegisterFloat(%cast1) : (memref<*xf32>) -> ()
+
+  %24 = call @mgpuMemGetDeviceMemRef1dFloat(%22) : (memref<?xf32>) -> (memref<?xf32>)
+  %26 = call @mgpuMemGetDeviceMemRef1dFloat(%23) : (memref<?xf32>) -> (memref<?xf32>)
+
+  // The x2 kernel writes 2 * input[0..1] into output[1..2]. Output elements
+  // 0 and 3 keep their initial value.
+  // CHECK: [1.23, 2.46, 2.46, 1.23]
+  call @vectransferx2(%24, %26) : (memref<?xf32>,  memref<?xf32>) -> ()
+  call @print_memref_f32(%cast1) : (memref<*xf32>) -> ()
+
+  // The x4 kernel overwrites all four output elements with 2 * input[0..3].
+  // CHECK: [2.46, 2.46, 2.46, 2.46]
+  call @vectransferx4(%24, %26) : (memref<?xf32>,  memref<?xf32>) -> ()
+  call @print_memref_f32(%cast1) : (memref<*xf32>) -> ()
+  return
+}
+
+// External functions resolved from the shared libraries passed via
+// --shared-libs on the RUN line.
+func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
+func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
+func @print_memref_f32(%ptr : memref<*xf32>)


        


More information about the Mlir-commits mailing list