[Mlir-commits] [mlir] 6bd88bb - [MLIR][NVVM] Add Ops for tcgen05 cp and shift (#127798)
llvmlistbot at llvm.org
Fri Feb 21 02:25:11 PST 2025
Author: Durgadoss R
Date: 2025-02-21T15:55:08+05:30
New Revision: 6bd88bb3ac82c4c4bd11b70cd01a2000c08db32d
URL: https://github.com/llvm/llvm-project/commit/6bd88bb3ac82c4c4bd11b70cd01a2000c08db32d
DIFF: https://github.com/llvm/llvm-project/commit/6bd88bb3ac82c4c4bd11b70cd01a2000c08db32d.diff
LOG: [MLIR][NVVM] Add Ops for tcgen05 cp and shift (#127798)
PR #127669 adds intrinsics for tcgen05.cp/shift.
This PR adds the corresponding NVVM Dialect Ops.
Lit tests are added to verify the lowering to the intrinsics.
Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
Added:
mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir
mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0de5a87e72c3f..0692e8e32dbf8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2810,6 +2810,114 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
}];
}
+def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> {
+ let summary = "Tcgen05 shift operation";
+ let description = [{
+ The `tcgen05.shift` is an asynchronous instruction that initiates
+ the shifting of 32-byte elements down by one row, across all rows
+ except the last. The operand `taddr` specifies the base address of
+ the matrix in Tensor Memory whose rows are to be shifted down.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-shift)
+ }];
+
+ let arguments = (ins LLVM_PointerTensor:$taddr,
+ DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
+
+ let assemblyFormat = "$taddr attr-dict `:` type(operands)";
+
+ string llvmBuilder = [{
+ auto id = ($group == NVVM::Tcgen05GroupKind::CTA_1) ?
+ llvm::Intrinsic::nvvm_tcgen05_shift_down_cg1 :
+ llvm::Intrinsic::nvvm_tcgen05_shift_down_cg2;
+ createIntrinsicCall(builder, id, {$taddr});
+ }];
+}
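Since this op's description carries no inline example, here is a minimal usage sketch mirroring the lit test added below (the function name is illustrative); the default `cta_1` group selects the `cg1` intrinsic, while `cta_2` selects `cg2`:

```mlir
llvm.func @tcgen05_shift_example(%taddr : !llvm.ptr<6>) {
  // Default group is cta_1; lowers to @llvm.nvvm.tcgen05.shift.down.cg1.
  nvvm.tcgen05.shift %taddr : !llvm.ptr<6>
  // The cta_2 group selects @llvm.nvvm.tcgen05.shift.down.cg2.
  nvvm.tcgen05.shift %taddr {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<6>
  llvm.return
}
```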
+
+def Shape128x256b : I32EnumAttrCase<"SHAPE_128x256b", 0, "shape_128x256b">;
+def Shape4x256b : I32EnumAttrCase<"SHAPE_4x256b", 1, "shape_4x256b">;
+def Shape128x128b : I32EnumAttrCase<"SHAPE_128x128b", 2, "shape_128x128b">;
+def Shape64x128b : I32EnumAttrCase<"SHAPE_64x128b", 3, "shape_64x128b">;
+def Shape32x128b : I32EnumAttrCase<"SHAPE_32x128b", 4, "shape_32x128b">;
+
+def Tcgen05CpShape : I32EnumAttr<"Tcgen05CpShape", "tcgen05 cp shapes",
+ [Shape128x256b, Shape4x256b, Shape128x128b, Shape64x128b, Shape32x128b]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+def Tcgen05CpShapeAttr : EnumAttr<NVVM_Dialect, Tcgen05CpShape, "tcgen05_cp_shape"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def Tcgen05CpMulticastNone: I32EnumAttrCase<"NONE", 0, "none">;
+def Tcgen05CpMulticastWarpx2_02_13: I32EnumAttrCase<"WARPX2_02_13", 1, "warpx2_02_13">;
+def Tcgen05CpMulticastWarpx2_01_23: I32EnumAttrCase<"WARPX2_01_23", 2, "warpx2_01_23">;
+def Tcgen05CpMulticastWarpx4: I32EnumAttrCase<"WARPX4", 3, "warpx4">;
+
+def Tcgen05CpMulticast : I32EnumAttr<"Tcgen05CpMulticast", "tcgen05 cp multicast",
+ [Tcgen05CpMulticastNone, Tcgen05CpMulticastWarpx2_02_13,
+ Tcgen05CpMulticastWarpx2_01_23, Tcgen05CpMulticastWarpx4]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+def Tcgen05CpMulticastAttr : EnumAttr<NVVM_Dialect, Tcgen05CpMulticast, "tcgen05_cp_multicast"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def FormatB6x16_P32: I32EnumAttrCase<"B6x16_P32", 0, "b6x16_p32">;
+def FormatB4x16_P64: I32EnumAttrCase<"B4x16_P64", 1, "b4x16_p64">;
+
+def Tcgen05CpSrcFormat : I32EnumAttr<"Tcgen05CpSrcFormat", "tcgen05 cp source format",
+ [FormatB6x16_P32, FormatB4x16_P64]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05_cp_src_fmt"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
+ let summary = "Tcgen05 copy operation";
+ let description = [{
+ The `tcgen05.cp` instruction initiates an asynchronous copy from shared
+ memory to the location in Tensor Memory specified by the address operand
+ `taddr`. The 64-bit register operand `smem_desc` is the matrix descriptor
+ of the source matrix in shared memory to be copied.
+
+ Example:
+ ```mlir
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ group = #nvvm.tcgen05_group<cta_2>,
+ shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+ }
+ ```
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-cp)
+ }];
+
+ let arguments = (ins
+ Tcgen05CpShapeAttr:$shape,
+ DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group,
+ DefaultValuedAttr<Tcgen05CpMulticastAttr, "Tcgen05CpMulticast::NONE">:$multicast,
+ OptionalAttr<Tcgen05CpSrcFormatAttr>:$srcFormat,
+ LLVM_PointerTensor:$taddr,
+ I64:$smem_desc);
+
+ let assemblyFormat = "$taddr`,` $smem_desc attr-dict";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(Operation &op);
+ }];
+
+ string llvmBuilder = [{
+ auto id = NVVM::Tcgen05CpOp::getIntrinsicID(*op);
+ createIntrinsicCall(builder, id, {$taddr, $smem_desc});
+ }];
+}
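In addition to the fully-attributed example in the description, only `shape` is required; `group` and `multicast` default to `cta_1` and `none`. A minimal form, as exercised by the lit tests added below (the function name is illustrative):

```mlir
llvm.func @tcgen05_cp_minimal(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
  // With the default attributes, this lowers to @llvm.nvvm.tcgen05.cp.128x256b.cg1.
  nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>}
  llvm.return
}
```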
+
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 62f0c21338111..54a09f81a3cd3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -75,6 +75,10 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
+//===----------------------------------------------------------------------===//
+// Verifier methods
+//===----------------------------------------------------------------------===//
+
// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
@@ -1107,6 +1111,38 @@ LogicalResult NVVM::BarrierOp::verify() {
return success();
}
+LogicalResult NVVM::Tcgen05CpOp::verify() {
+ auto mc = getMulticast();
+
+ using SH = Tcgen05CpShape;
+ using MC = Tcgen05CpMulticast;
+ switch (getShape()) {
+ case SH::SHAPE_128x256b:
+ case SH::SHAPE_128x128b:
+ case SH::SHAPE_4x256b:
+ if (mc != MC::NONE)
+ return emitError("Invalid multicast type for tcgen05.cp Op");
+ break;
+ case SH::SHAPE_64x128b:
+ if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
+ return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
+ "warpx2_02_13 for tcgen05.cp Op");
+ break;
+ case SH::SHAPE_32x128b:
+ if (mc != MC::WARPX4)
+ return emitError(
+ "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
+ break;
+ default:
+ return emitError("Invalid shape for tcgen05.cp Op");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// getIntrinsicID/getIntrinsicIDAndArgs methods
+//===----------------------------------------------------------------------===//
+
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
@@ -1314,6 +1350,46 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return id;
}
+#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
+ llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
+
+#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
+ is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
+ : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
+
+#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
+ [&]() -> auto { \
+ if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32) \
+ return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
+ if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64) \
+ return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
+ return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
+ }()
+
+llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
+ auto curOp = cast<NVVM::Tcgen05CpOp>(op);
+ bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
+ auto srcFmt = curOp.getSrcFormat();
+ auto mc = curOp.getMulticast();
+
+ switch (curOp.getShape()) {
+ case Tcgen05CpShape::SHAPE_128x256b:
+ return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
+ case Tcgen05CpShape::SHAPE_128x128b:
+ return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
+ case Tcgen05CpShape::SHAPE_4x256b:
+ return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
+ case Tcgen05CpShape::SHAPE_32x128b:
+ return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
+ case Tcgen05CpShape::SHAPE_64x128b:
+ return (mc == Tcgen05CpMulticast::WARPX2_01_23)
+ ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
+ : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
+ default:
+ llvm_unreachable("Invalid shape in tcgen05 cp Op");
+ }
+}
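Putting the pieces together, the intrinsic name is assembled from the shape (plus a multicast suffix for the 64x128b and 32x128b shapes), an optional source-format suffix, and a cg1/cg2 suffix chosen by the CTA group. One combination, matching a CHECK line in the new lit test (the function name is illustrative):

```mlir
llvm.func @tcgen05_cp_mapping(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
  // shape_64x128b + warpx2_01_23 + b6x16_p32 + cta_2 lowers to
  // @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b6x16_p32.cg2.
  nvvm.tcgen05.cp %taddr, %smem_desc {
    shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
    group = #nvvm.tcgen05_group<cta_2>,
    multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
    srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
  }
  llvm.return
}
```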
+
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir
new file mode 100644
index 0000000000000..91128cd00c873
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir
@@ -0,0 +1,136 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_128x256b
+llvm.func @nvvm_tcgen05_cp_128x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>}
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, group = #nvvm.tcgen05_group<cta_2>}
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+ }
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+ }
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_4x256b
+llvm.func @nvvm_tcgen05_cp_4x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>}
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>, group = #nvvm.tcgen05_group<cta_2>}
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+ }
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+ }
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_128x128b
+llvm.func @nvvm_tcgen05_cp_128x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>}
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>, group = #nvvm.tcgen05_group<cta_2>}
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+ }
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+ }
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_64x128b
+llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
+ }
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
+ }
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b4x16_p64.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+ group = #nvvm.tcgen05_group<cta_1>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+ }
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+ }
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_32x128b
+llvm.func @nvvm_tcgen05_cp_32x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx4>
+ }
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx4>
+ }
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+ }
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b6x16_p32.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+ group = #nvvm.tcgen05_group<cta_1>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+ }
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir
new file mode 100644
index 0000000000000..48753a3fdb21b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @llvm_nvvm_tcgen05_shift
+llvm.func @llvm_nvvm_tcgen05_shift(%taddr : !llvm.ptr<6>) {
+ // CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg1(ptr addrspace(6) %{{.*}})
+ nvvm.tcgen05.shift %taddr : !llvm.ptr<6>
+
+ // CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg2(ptr addrspace(6) %{{.*}})
+ nvvm.tcgen05.shift %taddr {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<6>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 8957377607dad..4fca7fd801dbe 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -122,3 +122,33 @@ llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
%res = nvvm.cvt.float.to.tf32 %src
llvm.return %res : i32
}
+
+// -----
+
+llvm.func @nvvm_tcgen05_cp_128x256b_mc(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // expected-error @below {{Invalid multicast type for tcgen05.cp Op}}
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>}
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_tcgen05_cp_32x128b_wx2(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // expected-error @below {{Shape 32x128b requires multicast warpx4 for tcgen05.cp Op}}
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>
+ }
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // expected-error @below {{Shape 64x128b requires multicast warpx2_01_23 or warpx2_02_13 for tcgen05.cp Op}}
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx4>
+ }
+ llvm.return
+}