[Mlir-commits] [mlir] 3666362 - [mlir][nvvm] Introduce redux op

Guray Ozen llvmlistbot at llvm.org
Fri Jan 20 03:14:31 PST 2023


Author: Guray Ozen
Date: 2023-01-20T12:14:24+01:00
New Revision: 36663626ee336905745cb1c259b3b65c9ff656bf

URL: https://github.com/llvm/llvm-project/commit/36663626ee336905745cb1c259b3b65c9ff656bf
DIFF: https://github.com/llvm/llvm-project/commit/36663626ee336905745cb1c259b3b65c9ff656bf.diff

LOG: [mlir][nvvm] Introduce redux op

The PTX model has `redux.sync`, which performs a reduction operation on the data from each predicated active thread in the thread group. It is only available on sm_80 and newer.

This revision adds redux as an op to the NVVM dialect.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D142088

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
    mlir/test/Dialect/LLVMIR/nvvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8a16d6430423..289064cbed5b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -135,6 +135,45 @@ def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [Pure], 1> {
   let assemblyFormat = "$arg attr-dict `:` type($res)";
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM redux op definitions
+//===----------------------------------------------------------------------===//
+
+// Reduction operations accepted by nvvm.redux.sync. `umax`/`umin` select the
+// unsigned variants of `max`/`min`; `and`/`or`/`xor` are bitwise reductions.
+def ReduxKindNone : I32EnumAttrCase<"NONE", 0, "none">;
+def ReduxKindAdd  : I32EnumAttrCase<"ADD", 1, "add">;
+def ReduxKindAnd  : I32EnumAttrCase<"AND", 2, "and">;
+def ReduxKindMax  : I32EnumAttrCase<"MAX", 3, "max">;
+def ReduxKindMin  : I32EnumAttrCase<"MIN", 4, "min">;
+def ReduxKindOr   : I32EnumAttrCase<"OR", 5, "or">;
+def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">;
+def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">;
+def ReduxKindXor  : I32EnumAttrCase<"XOR", 8, "xor">;
+
+/// Enum attribute of the different redux kinds. ReduxKindNone is deliberately
+/// not listed, so `none` is not a valid kind for nvvm.redux.sync.
+def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
+  [ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr,
+    ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+
+/// Dialect attribute wrapping ReduxKind; used as the `kind` of nvvm.redux.sync.
+def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
+
+def NVVM_ReduxOp :
+  NVVM_Op<"redux.sync">,
+  // Only i32 has an intrinsic mapping in the LLVM IR translation, so the type
+  // is enforced here; other types previously verified and then hit
+  // llvm_unreachable during translation.
+  Results<(outs I32:$res)>,
+  Arguments<(ins I32:$val,
+                 ReduxKindAttr:$kind,
+                 I32:$mask_and_clamp)> {
+  let summary = "Redux Sync Op";
+  let description = [{
+    `nvvm.redux.sync` performs the reduction `kind` on `val` across the
+    threads selected by the member mask. It lowers to the
+    `llvm.nvvm.redux.sync.*` intrinsics and requires sm_80 or newer. Only
+    32-bit integer values are supported.
+
+    Despite its name, `mask_and_clamp` carries only the PTX member mask; it is
+    forwarded unchanged as the intrinsic's second operand.
+
+    Example:
+    ```mlir
+    %r = nvvm.redux.sync add %value, %mask : i32 -> i32
+    ```
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-redux-sync)
+  }];
+  string llvmBuilder = [{
+      auto intId = getReduxIntrinsicId($_resultType, $kind);
+      $res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});
+  }];
+  let assemblyFormat = [{
+    $kind $val `,` $mask_and_clamp attr-dict `:` type($val) `->` type($res)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM synchronization op definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index feaf5ca3f563..d7f1bb6a7be7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -25,6 +25,32 @@ using namespace mlir;
 using namespace mlir::LLVM;
 using mlir::LLVM::detail::createIntrinsicCall;
 
+/// Returns the `llvm.nvvm.redux.sync.*` intrinsic that implements `kind`.
+/// Only a 32-bit integer `resultType` has a mapping here; any other type
+/// reaches llvm_unreachable, so callers must guarantee an i32 result.
+static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
+                                               NVVM::ReduxKind kind) {
+  if (!resultType->isIntegerTy(32))
+    llvm_unreachable("unsupported data type for redux");
+
+  // Cases follow the declaration order of the ReduxKind enum cases.
+  switch (kind) {
+  case NVVM::ReduxKind::ADD:
+    return llvm::Intrinsic::nvvm_redux_sync_add;
+  case NVVM::ReduxKind::AND:
+    return llvm::Intrinsic::nvvm_redux_sync_and;
+  case NVVM::ReduxKind::MAX:
+    return llvm::Intrinsic::nvvm_redux_sync_max;
+  case NVVM::ReduxKind::MIN:
+    return llvm::Intrinsic::nvvm_redux_sync_min;
+  case NVVM::ReduxKind::OR:
+    return llvm::Intrinsic::nvvm_redux_sync_or;
+  case NVVM::ReduxKind::UMAX:
+    return llvm::Intrinsic::nvvm_redux_sync_umax;
+  case NVVM::ReduxKind::UMIN:
+    return llvm::Intrinsic::nvvm_redux_sync_umin;
+  case NVVM::ReduxKind::XOR:
+    return llvm::Intrinsic::nvvm_redux_sync_xor;
+  }
+  // Every enumerator returns above; this only guards against a value outside
+  // the enum's range.
+  llvm_unreachable("unknown redux kind");
+}
+
 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
                                               NVVM::ShflKind kind,
                                               bool withPredicate) {

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 150f3085df6d..2e3b20b1d661 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -310,6 +310,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
   %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
   llvm.return
 }
+
+// Round-trips nvvm.redux.sync for every ReduxKind; the second operand is the
+// member mask.
+// CHECK-LABEL: llvm.func @redux_sync
+llvm.func @redux_sync(%value : i32, %mask : i32) -> i32 {
+  // CHECK: nvvm.redux.sync add %{{.*}}, %{{.*}} : i32 -> i32
+  %r1 = nvvm.redux.sync add %value, %mask : i32 -> i32
+  // CHECK: nvvm.redux.sync max %{{.*}}, %{{.*}} : i32 -> i32
+  %r2 = nvvm.redux.sync max %value, %mask : i32 -> i32
+  // CHECK: nvvm.redux.sync min %{{.*}}, %{{.*}} : i32 -> i32
+  %r3 = nvvm.redux.sync min %value, %mask : i32 -> i32
+  // CHECK: nvvm.redux.sync umax %{{.*}}, %{{.*}} : i32 -> i32
+  %r4 = nvvm.redux.sync umax %value, %mask : i32 -> i32
+  // CHECK: nvvm.redux.sync umin %{{.*}}, %{{.*}} : i32 -> i32
+  %r5 = nvvm.redux.sync umin %value, %mask : i32 -> i32
+  // CHECK: nvvm.redux.sync and %{{.*}}, %{{.*}} : i32 -> i32
+  %r6 = nvvm.redux.sync and %value, %mask : i32 -> i32
+  // CHECK: nvvm.redux.sync or %{{.*}}, %{{.*}} : i32 -> i32
+  %r7 = nvvm.redux.sync or %value, %mask : i32 -> i32
+  // CHECK: nvvm.redux.sync xor %{{.*}}, %{{.*}} : i32 -> i32
+  %r8 = nvvm.redux.sync xor %value, %mask : i32 -> i32
+  llvm.return %r1 : i32
+}
+
+
 // -----
 
 // expected-error at below {{attribute attached to unexpected op}}


        


More information about the Mlir-commits mailing list