[Mlir-commits] [mlir] 6bd88bb - [MLIR][NVVM] Add Ops for tcgen05 cp and shift (#127798)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 21 02:25:11 PST 2025


Author: Durgadoss R
Date: 2025-02-21T15:55:08+05:30
New Revision: 6bd88bb3ac82c4c4bd11b70cd01a2000c08db32d

URL: https://github.com/llvm/llvm-project/commit/6bd88bb3ac82c4c4bd11b70cd01a2000c08db32d
DIFF: https://github.com/llvm/llvm-project/commit/6bd88bb3ac82c4c4bd11b70cd01a2000c08db32d.diff

LOG: [MLIR][NVVM] Add Ops for tcgen05 cp and shift (#127798)

PR #127669 adds intrinsics for tcgen05.cp/shift.
This PR adds the corresponding NVVM Dialect Ops.

lit tests are added to verify the lowering
to the intrinsics.

Signed-off-by: Durgadoss R <durgadossr at nvidia.com>

Added: 
    mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir
    mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0de5a87e72c3f..0692e8e32dbf8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2810,6 +2810,114 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
   }];
 }
 
+def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> {
+  let summary = "Tcgen05 shift operation";
+  let description = [{
+    `tcgen05.shift` is an asynchronous instruction that initiates the downward
+    shift, by one row, of the 32-byte elements across all rows except the
+    last. Operand `taddr` is the base address in Tensor Memory of the matrix
+    to shift; attribute `group` selects cta_1 (the default) or cta_2.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-shift)
+  }];
+
+  let assemblyFormat = "$taddr attr-dict `:` type(operands)";
+
+  let arguments = (ins LLVM_PointerTensor:$taddr,
+    DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
+
+  string llvmBuilder = [{
+    bool isCta1 = $group == NVVM::Tcgen05GroupKind::CTA_1;
+    createIntrinsicCall(builder, isCta1
+        ? llvm::Intrinsic::nvvm_tcgen05_shift_down_cg1
+        : llvm::Intrinsic::nvvm_tcgen05_shift_down_cg2, {$taddr});
+  }];
+}
+
+def Shape128x256b : I32EnumAttrCase<"SHAPE_128x256b", 0, "shape_128x256b">;
+def Shape4x256b   : I32EnumAttrCase<"SHAPE_4x256b",   1, "shape_4x256b">;
+def Shape128x128b : I32EnumAttrCase<"SHAPE_128x128b", 2, "shape_128x128b">;
+def Shape64x128b  : I32EnumAttrCase<"SHAPE_64x128b",  3, "shape_64x128b">;
+def Shape32x128b  : I32EnumAttrCase<"SHAPE_32x128b",  4, "shape_32x128b">;
+// Matrix shapes for the `shape` attribute of nvvm.tcgen05.cp (#nvvm.tcgen05_cp_shape<...>).
+def Tcgen05CpShape : I32EnumAttr<"Tcgen05CpShape", "tcgen05 cp shapes",
+  [Shape128x256b, Shape4x256b, Shape128x128b, Shape64x128b, Shape32x128b]> {
+    let cppNamespace = "::mlir::NVVM";
+    let genSpecializedAttr = 0;
+}
+def Tcgen05CpShapeAttr : EnumAttr<NVVM_Dialect, Tcgen05CpShape, "tcgen05_cp_shape"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def Tcgen05CpMulticastNone: I32EnumAttrCase<"NONE", 0, "none">;
+def Tcgen05CpMulticastWarpx2_02_13: I32EnumAttrCase<"WARPX2_02_13", 1, "warpx2_02_13">;
+def Tcgen05CpMulticastWarpx2_01_23: I32EnumAttrCase<"WARPX2_01_23", 2, "warpx2_01_23">;
+def Tcgen05CpMulticastWarpx4: I32EnumAttrCase<"WARPX4", 3, "warpx4">;
+// Multicast modes for tcgen05.cp (default none); Tcgen05CpOp::verify() checks them per shape.
+def Tcgen05CpMulticast : I32EnumAttr<"Tcgen05CpMulticast", "tcgen05 cp multicast",
+  [Tcgen05CpMulticastNone, Tcgen05CpMulticastWarpx2_02_13,
+   Tcgen05CpMulticastWarpx2_01_23, Tcgen05CpMulticastWarpx4]> {
+    let cppNamespace = "::mlir::NVVM";
+    let genSpecializedAttr = 0;
+}
+def Tcgen05CpMulticastAttr : EnumAttr<NVVM_Dialect, Tcgen05CpMulticast, "tcgen05_cp_multicast"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def FormatB6x16_P32: I32EnumAttrCase<"B6x16_P32", 0, "b6x16_p32">;
+def FormatB4x16_P64: I32EnumAttrCase<"B4x16_P64", 1, "b4x16_p64">;
+// Optional source format of tcgen05.cp; when unset, the suffix-less intrinsic is selected.
+def Tcgen05CpSrcFormat : I32EnumAttr<"Tcgen05CpSrcFormat", "tcgen05 cp source format",
+  [FormatB6x16_P32, FormatB4x16_P64]> {
+    let cppNamespace = "::mlir::NVVM";
+    let genSpecializedAttr = 0;
+}
+def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05_cp_src_fmt"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
+  let summary = "Tcgen05 copy operation";
+  let description = [{
+    `tcgen05.cp` initiates an asynchronous copy from shared memory into
+    Tensor Memory at address `taddr`; the 64-bit operand `smem_desc` is the
+    matrix descriptor of the source matrix in shared memory. `multicast` must
+    suit `shape`: 128x256b, 128x128b and 4x256b take `none` (the default),
+    64x128b takes warpx2_01_23 or warpx2_02_13, and 32x128b takes warpx4.
+
+    Example:
+    ```mlir
+      nvvm.tcgen05.cp %taddr, %smem_desc {
+        group = #nvvm.tcgen05_group<cta_2>,
+        shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+        multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
+        srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+      }
+    ```
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-cp)
+  }];
+
+  let arguments = (ins
+    Tcgen05CpShapeAttr:$shape,
+    DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group,
+    DefaultValuedAttr<Tcgen05CpMulticastAttr, "Tcgen05CpMulticast::NONE">:$multicast,
+    OptionalAttr<Tcgen05CpSrcFormatAttr>:$srcFormat,
+    LLVM_PointerTensor:$taddr,
+    I64:$smem_desc);
+
+  let hasVerifier = 1;
+  let assemblyFormat = "$taddr`,` $smem_desc attr-dict";
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(Operation &op);
+  }];
+
+  string llvmBuilder = [{
+    createIntrinsicCall(builder, NVVM::Tcgen05CpOp::getIntrinsicID(*op),
+                        {$taddr, $smem_desc});
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 62f0c21338111..54a09f81a3cd3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -75,6 +75,10 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
 
 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
 
+//===----------------------------------------------------------------------===//
+// Verifier methods
+//===----------------------------------------------------------------------===//
+
 // This verifier is shared among the following Ops:
 // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
 // CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
@@ -1107,6 +1111,38 @@ LogicalResult NVVM::BarrierOp::verify() {
   return success();
 }
 
+LogicalResult NVVM::Tcgen05CpOp::verify() {
+  // Each shape admits only specific multicast modes; reject all others.
+  using SH = Tcgen05CpShape;
+  using MC = Tcgen05CpMulticast;
+  auto mc = getMulticast();
+  switch (getShape()) {
+  case SH::SHAPE_128x256b:
+  case SH::SHAPE_128x128b:
+  case SH::SHAPE_4x256b:
+    if (mc != MC::NONE)
+      return emitError("Invalid multicast type for tcgen05.cp Op");
+    return success();
+  case SH::SHAPE_64x128b:
+    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
+      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
+                       "warpx2_02_13 for tcgen05.cp Op");
+    return success();
+  case SH::SHAPE_32x128b:
+    if (mc != MC::WARPX4)
+      return emitError(
+          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
+    return success();
+  }
+  // No `default:` so -Wswitch flags any new Tcgen05CpShape left unhandled;
+  // an out-of-range value still gets a diagnostic instead of success().
+  return emitError("Invalid shape for tcgen05.cp Op");
+}
+
+//===----------------------------------------------------------------------===//
+// getIntrinsicID/getIntrinsicIDAndArgs methods
+//===----------------------------------------------------------------------===//
+
 #define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
   llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
 
@@ -1314,6 +1350,46 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
   return id;
 }
 
+#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
+  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
+// Picks the _cg2 or _cg1 variant; the whole expansion is parenthesized.
+#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
+  ((is_2cta) ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                        \
+             : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1))
+// Picks the src-format suffix; otherwise the suffix-less intrinsic is used.
+#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
+  [&]() -> auto {                                                              \
+    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
+      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
+    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
+      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
+    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
+  }()
+
+llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
+  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
+  bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
+  auto srcFmt = curOp.getSrcFormat();
+  auto mc = curOp.getMulticast();
+  // Intrinsic: tcgen05.cp.<shape>[_<multicast>][.<srcFormat>].<cg1|cg2>.
+  switch (curOp.getShape()) {
+  case Tcgen05CpShape::SHAPE_128x256b:
+    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
+  case Tcgen05CpShape::SHAPE_128x128b:
+    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
+  case Tcgen05CpShape::SHAPE_4x256b:
+    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
+  case Tcgen05CpShape::SHAPE_32x128b:
+    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
+  case Tcgen05CpShape::SHAPE_64x128b:
+    // verify() restricts this shape to warpx2_01_23 or warpx2_02_13.
+    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
+               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
+               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
+  }
+  llvm_unreachable("Invalid shape in tcgen05 cp Op");
+}
+
 /// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
 /// have ConstantRangeAttr.
 static void nvvmInferResultRanges(Operation *op, Value result,

diff  --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir
new file mode 100644
index 0000000000000..91128cd00c873
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir
@@ -0,0 +1,136 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_128x256b
+llvm.func @nvvm_tcgen05_cp_128x256b(%tmem_addr : !llvm.ptr<6>, %desc : i64) {
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>}
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {group = #nvvm.tcgen05_group<cta_2>, shape = #nvvm.tcgen05_cp_shape<shape_128x256b>}
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    shape = #nvvm.tcgen05_cp_shape<shape_128x256b>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    shape = #nvvm.tcgen05_cp_shape<shape_128x256b>
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_4x256b
+llvm.func @nvvm_tcgen05_cp_4x256b(%tmem_addr : !llvm.ptr<6>, %desc : i64) {
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>}
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {group = #nvvm.tcgen05_group<cta_2>, shape = #nvvm.tcgen05_cp_shape<shape_4x256b>}
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    shape = #nvvm.tcgen05_cp_shape<shape_4x256b>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    shape = #nvvm.tcgen05_cp_shape<shape_4x256b>
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_128x128b
+llvm.func @nvvm_tcgen05_cp_128x128b(%tmem_addr : !llvm.ptr<6>, %desc : i64) {
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>}
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {group = #nvvm.tcgen05_group<cta_2>, shape = #nvvm.tcgen05_cp_shape<shape_128x128b>}
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    shape = #nvvm.tcgen05_cp_shape<shape_128x128b>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    shape = #nvvm.tcgen05_cp_shape<shape_128x128b>
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_64x128b
+llvm.func @nvvm_tcgen05_cp_64x128b(%tmem_addr : !llvm.ptr<6>, %desc : i64) {
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>,
+    shape = #nvvm.tcgen05_cp_shape<shape_64x128b>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    shape = #nvvm.tcgen05_cp_shape<shape_64x128b>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b4x16_p64.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>,
+    group = #nvvm.tcgen05_group<cta_1>,
+    shape = #nvvm.tcgen05_cp_shape<shape_64x128b>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    shape = #nvvm.tcgen05_cp_shape<shape_64x128b>
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_32x128b
+llvm.func @nvvm_tcgen05_cp_32x128b(%tmem_addr : !llvm.ptr<6>, %desc : i64) {
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
+    shape = #nvvm.tcgen05_cp_shape<shape_32x128b>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    shape = #nvvm.tcgen05_cp_shape<shape_32x128b>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    shape = #nvvm.tcgen05_cp_shape<shape_32x128b>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b6x16_p32.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
+    group = #nvvm.tcgen05_group<cta_1>,
+    shape = #nvvm.tcgen05_cp_shape<shape_32x128b>
+  }
+  llvm.return
+}

diff  --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir
new file mode 100644
index 0000000000000..48753a3fdb21b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @llvm_nvvm_tcgen05_shift
+llvm.func @llvm_nvvm_tcgen05_shift(%tmem_addr : !llvm.ptr<6>) {
+  // CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg1(ptr addrspace(6) %{{.*}})
+  nvvm.tcgen05.shift %tmem_addr : !llvm.ptr<6>
+  // CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg2(ptr addrspace(6) %{{.*}})
+  nvvm.tcgen05.shift %tmem_addr {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<6>
+
+  llvm.return
+}

diff  --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 8957377607dad..4fca7fd801dbe 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -122,3 +122,33 @@ llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
   %res = nvvm.cvt.float.to.tf32 %src
   llvm.return %res : i32
 }
+
+// -----
+
+llvm.func @nvvm_tcgen05_cp_128x256b_mc(%tmem_addr : !llvm.ptr<6>, %desc : i64) {
+  // expected-error @+1 {{Invalid multicast type for tcgen05.cp Op}}
+  nvvm.tcgen05.cp %tmem_addr, %desc {multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>, shape = #nvvm.tcgen05_cp_shape<shape_128x256b>}
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_tcgen05_cp_32x128b_wx2(%tmem_addr : !llvm.ptr<6>, %desc : i64) {
+  // expected-error @+1 {{Shape 32x128b requires multicast warpx4 for tcgen05.cp Op}}
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
+    shape = #nvvm.tcgen05_cp_shape<shape_32x128b>
+  }
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_tcgen05_cp_64x128b(%tmem_addr : !llvm.ptr<6>, %desc : i64) {
+  // expected-error @+1 {{Shape 64x128b requires multicast warpx2_01_23 or warpx2_02_13 for tcgen05.cp Op}}
+  nvvm.tcgen05.cp %tmem_addr, %desc {
+    multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
+    shape = #nvvm.tcgen05_cp_shape<shape_64x128b>
+  }
+  llvm.return
+}


        


More information about the Mlir-commits mailing list