[Mlir-commits] [mlir] 2f33f11 - [mlir][NVVM] Add ldmatrix op to NVVM dialect

Thomas Raoux llvmlistbot at llvm.org
Thu Mar 10 12:38:24 PST 2022


Author: Thomas Raoux
Date: 2022-03-10T20:37:17Z
New Revision: 2f33f11428c1832a413d5ca617948ac5cc397385

URL: https://github.com/llvm/llvm-project/commit/2f33f11428c1832a413d5ca617948ac5cc397385
DIFF: https://github.com/llvm/llvm-project/commit/2f33f11428c1832a413d5ca617948ac5cc397385.diff

LOG: [mlir][NVVM] Add ldmatrix op to NVVM dialect

Differential Revision: https://reviews.llvm.org/D121347

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 196fbdd951053..d65b525eacf6f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -634,4 +634,48 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
+def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
+  Results<(outs AnyType:$res)>,
+  Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
+
+  let summary = "cooperative matrix load";
+
+  string llvmBuilder = [{
+      auto operands = moduleTranslation.lookupValues(opInst.getOperands());
+      auto intId = getLdMatrixIntrinsicId($layout, $num);
+      $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
+  }];
+
+  string baseDescription = [{
+    The `nvvm.ldmatrix` operation collectively loads one or more matrices across
+    all threads in a warp from the location in shared memory indicated by the
+    address operand `ptr`.
+
+    The attribute `num` indicates how many 8x8 16-bit matrices are to be loaded.
+
+    All the threads in the warp must execute the same ldmatrix operations.
+
+    Each row of 8 elements needs to be consecutive in memory. Each lane of the
+    warp contains the start address of a row of 8 elements laid out as below:
+
+    ```
+    num | Threads 0--7 | Threads 8--15  | Threads 16--31
+    1   | addr0--addr7 |                |
+    2   | addr0--addr7 | addr8--addr15  |
+    4   | addr0--addr7 | addr8--addr15  | addr16--addr31
+    ```
+
+    Example:
+    ```mlir
+    %l1 = nvvm.ldmatrix %ptr {num = 1 : i32, layout = #nvvm.mma_layout<row>} :
+      (!llvm.ptr<i32, 3>) -> i32
+    %l2 = nvvm.ldmatrix %ptr {num = 4 : i32, layout = #nvvm.mma_layout<row>} :
+      (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    ```
+  }];
+
+  let assemblyFormat = "$ptr attr-dict `:` functional-type($ptr, $res)";
+  let hasVerifier = 1;
+}
+
 #endif // NVVMIR_OPS

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3557525eb0fe0..5256be06d6d5a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -219,6 +219,28 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
   return success();
 }
 
+LogicalResult NVVM::LdMatrixOp::verify() {
+  unsigned addressSpace =
+      ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
+  if (addressSpace != 3)
+    return emitOpError("expected source pointer in memory space 3");
+
+  if (num() != 1 && num() != 2 && num() != 4)
+    return emitOpError("expected num attribute to be 1, 2 or 4");
+
+  Type i32 = IntegerType::get(getContext(), 32);
+  if (num() == 1 && getType() != i32)
+    return emitOpError("expected destination type is i32");
+  if (num() == 2 || num() == 4) {
+    Type dstType = LLVM::LLVMStructType::getLiteral(
+        getContext(), SmallVector<Type>(num(), i32));
+    if (getType() != dstType)
+      return emitOpError("expected destination type is a structure of ")
+             << num() << " elements of type i32";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 4b5ca8fa86b35..88f8af0eef136 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -64,6 +64,35 @@ static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
   llvm_unreachable("unknown shuffle kind");
 }
 
+/// Return the intrinsic ID associated with ldmatrix for the given parameters.
+static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
+                                                  int32_t num) {
+  if (layout == NVVM::MMALayout::col) {
+    switch (num) {
+    case 1:
+      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
+    case 2:
+      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
+    case 4:
+      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
+    default:
+      llvm_unreachable("unsupported number of matrix");
+    }
+
+  } else {
+    switch (num) {
+    case 1:
+      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
+    case 2:
+      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
+    case 4:
+      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
+    default:
+      llvm_unreachable("unsupported number of matrix");
+    }
+  }
+}
+
 namespace {
 /// Implementation of the dialect interface that converts operations belonging
 /// to the NVVM dialect to LLVM IR.

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index c64efeededf59..94c3446821edc 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1191,6 +1191,38 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
 
 // -----
 
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32>) -> i32
+  llvm.return
+}
+
+// -----
+
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
+  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
+  llvm.return
+}
+
+// -----
+
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32)>
+  llvm.return
+}
+
+// -----
+
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
+  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+  llvm.return
+}
+
+// -----
+
 llvm.func @caller() {
   // expected-error @below {{expected function call to produce a value}}
   llvm.call @callee() : () -> ()

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index db57b1e34f28c..2b191541c3b02 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -105,6 +105,16 @@ llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
   llvm.return
 }
 
+// CHECK-LABEL: llvm.func @ld_matrix
+llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
+  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<i32, 3>) -> i32
+  %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
+  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return
+}
 // -----
 
 // expected-error@below {{attribute attached to unexpected op}}

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index b66533040e32a..ef7a1f9410598 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -176,6 +176,17 @@ llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
   llvm.return
 }
 
+// CHECK-LABEL: @ld_matrix(
+llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
+  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+  %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return
+}
+
 // This function has the "kernel" attribute attached and should appear in the
 // NVVM annotations after conversion.
 llvm.func @kernel_func() attributes {nvvm.kernel} {


        


More information about the Mlir-commits mailing list