[Mlir-commits] [mlir] 3666362 - [mlir][nvvm] Introduce redux op
Guray Ozen
llvmlistbot at llvm.org
Fri Jan 20 03:14:31 PST 2023
Author: Guray Ozen
Date: 2023-01-20T12:14:24+01:00
New Revision: 36663626ee336905745cb1c259b3b65c9ff656bf
URL: https://github.com/llvm/llvm-project/commit/36663626ee336905745cb1c259b3b65c9ff656bf
DIFF: https://github.com/llvm/llvm-project/commit/36663626ee336905745cb1c259b3b65c9ff656bf.diff
LOG: [mlir][nvvm] Introduce redux op
The PTX model has `redux.sync`, which performs a reduction operation on the data from each predicated active thread in the thread group. It is only available on sm80+.
This revision adds redux as an op to the NVVM dialect.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D142088
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
mlir/test/Dialect/LLVMIR/nvvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8a16d6430423..289064cbed5b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -135,6 +135,45 @@ def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [Pure], 1> {
let assemblyFormat = "$arg attr-dict `:` type($res)";
}
+//===----------------------------------------------------------------------===//
+// NVVM redux op definitions
+//===----------------------------------------------------------------------===//
+
+// Individual redux kind cases. Note: `ReduxKindNone` is declared but is not
+// listed in the `ReduxKind` enum below, so `none` is not an accepted kind.
+def ReduxKindNone : I32EnumAttrCase<"NONE", 0, "none">;
+def ReduxKindAdd  : I32EnumAttrCase<"ADD", 1, "add">;
+def ReduxKindAnd  : I32EnumAttrCase<"AND", 2, "and">;
+def ReduxKindMax  : I32EnumAttrCase<"MAX", 3, "max">;
+def ReduxKindMin  : I32EnumAttrCase<"MIN", 4, "min">;
+def ReduxKindOr   : I32EnumAttrCase<"OR", 5, "or">;
+def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">;
+def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">;
+def ReduxKindXor  : I32EnumAttrCase<"XOR", 8, "xor">;
+
+/// Enum attribute of the different kinds.
+def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
+  [ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr,
+   ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+
+// Dialect attribute wrapping `ReduxKind`, using the mnemonic `redux_kind`.
+def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
+
+def NVVM_ReduxOp :
+  NVVM_Op<"redux.sync">,
+  Results<(outs I32:$res)>,
+  Arguments<(ins I32:$val,
+                 ReduxKindAttr:$kind,
+                 I32:$mask_and_clamp)> {
+  let summary = "Reduction across the active threads of a thread group";
+  let description = [{
+    Lowers to the PTX `redux.sync` instruction, which applies the reduction
+    selected by `kind` to `val` across the predicated active threads of the
+    thread group. The instruction is only available on sm80+.
+
+    `val` and the result are restricted to i32. The LLVM IR translation only
+    supports 32-bit integers, so rejecting other types here turns an
+    `llvm_unreachable` during translation into a verifier error.
+
+    `mask_and_clamp` is passed unchanged as the second operand of the
+    selected `llvm.nvvm.redux.sync.*` intrinsic.
+  }];
+  string llvmBuilder = [{
+    auto intId = getReduxIntrinsicId($_resultType, $kind);
+    $res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});
+  }];
+  let assemblyFormat = [{
+    $kind $val `,` $mask_and_clamp attr-dict `:` type($val) `->` type($res)
+  }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM synchronization op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index feaf5ca3f563..d7f1bb6a7be7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -25,6 +25,32 @@ using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::createIntrinsicCall;
+/// Selects the `llvm.nvvm.redux.sync.*` intrinsic implementing `kind`.
+/// Only an i32 `resultType` is handled; any other result type, or a kind
+/// without a matching intrinsic, ends in `llvm_unreachable`.
+static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
+                                               NVVM::ReduxKind kind) {
+  if (resultType->isIntegerTy(32)) {
+    switch (kind) {
+    case NVVM::ReduxKind::ADD:
+      return llvm::Intrinsic::nvvm_redux_sync_add;
+    case NVVM::ReduxKind::AND:
+      return llvm::Intrinsic::nvvm_redux_sync_and;
+    case NVVM::ReduxKind::MAX:
+      return llvm::Intrinsic::nvvm_redux_sync_max;
+    case NVVM::ReduxKind::MIN:
+      return llvm::Intrinsic::nvvm_redux_sync_min;
+    case NVVM::ReduxKind::OR:
+      return llvm::Intrinsic::nvvm_redux_sync_or;
+    case NVVM::ReduxKind::UMAX:
+      return llvm::Intrinsic::nvvm_redux_sync_umax;
+    case NVVM::ReduxKind::UMIN:
+      return llvm::Intrinsic::nvvm_redux_sync_umin;
+    case NVVM::ReduxKind::XOR:
+      return llvm::Intrinsic::nvvm_redux_sync_xor;
+    }
+    llvm_unreachable("unknown redux kind");
+  }
+  llvm_unreachable("unsupported data type for redux");
+}
+
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
NVVM::ShflKind kind,
bool withPredicate) {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 150f3085df6d..2e3b20b1d661 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -310,6 +310,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
%l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
llvm.return
}
+
+// Round-trip test for every accepted redux kind; each CHECK pins the full
+// printed form (kind, both operands, and the `: type -> type` suffix).
+// CHECK-LABEL: llvm.func @redux_sync
+llvm.func @redux_sync(%value : i32, %mask_and_clamp : i32) -> i32 {
+  // CHECK: nvvm.redux.sync add %{{.*}}, %{{.*}} : i32 -> i32
+  %r1 = nvvm.redux.sync add %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync max %{{.*}}, %{{.*}} : i32 -> i32
+  %r2 = nvvm.redux.sync max %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync min %{{.*}}, %{{.*}} : i32 -> i32
+  %r3 = nvvm.redux.sync min %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync umax %{{.*}}, %{{.*}} : i32 -> i32
+  %r4 = nvvm.redux.sync umax %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync umin %{{.*}}, %{{.*}} : i32 -> i32
+  %r5 = nvvm.redux.sync umin %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync and %{{.*}}, %{{.*}} : i32 -> i32
+  %r6 = nvvm.redux.sync and %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync or %{{.*}}, %{{.*}} : i32 -> i32
+  %r7 = nvvm.redux.sync or %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync xor %{{.*}}, %{{.*}} : i32 -> i32
+  %r8 = nvvm.redux.sync xor %value, %mask_and_clamp : i32 -> i32
+  llvm.return %r1 : i32
+}
+
+
// -----
// expected-error @below {{attribute attached to unexpected op}}
More information about the Mlir-commits
mailing list