[Mlir-commits] [mlir] 47555d7 - [mlir][gpu] Extend shuffle op modes and add nvvm lowering

Thomas Raoux llvmlistbot at llvm.org
Fri Nov 19 11:14:50 PST 2021


Author: Thomas Raoux
Date: 2021-11-19T11:14:31-08:00
New Revision: 47555d73f6538cf2c092a7314e3c82c631ce4ccb

URL: https://github.com/llvm/llvm-project/commit/47555d73f6538cf2c092a7314e3c82c631ce4ccb
DIFF: https://github.com/llvm/llvm-project/commit/47555d73f6538cf2c092a7314e3c82c631ce4ccb.diff

LOG: [mlir][gpu] Extend shuffle op modes and add nvvm lowering

Add up, down and idx modes to gpu shuffle ops, also change the mode from
string to enum

Differential Revision: https://reviews.llvm.org/D114188

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
    mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
    mlir/test/Dialect/GPU/ops.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index adefba30d9a8e..71c13f878b658 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -647,13 +647,21 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
 }
 
 def GPU_ShuffleOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
+def GPU_ShuffleOpDown : StrEnumAttrCase<"DOWN", -1, "down">;
+def GPU_ShuffleOpUp : StrEnumAttrCase<"UP", -1, "up">;
+def GPU_ShuffleOpIdx : StrEnumAttrCase<"IDX", -1, "idx">;
 
 def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr",
     "Indexing modes supported by gpu.shuffle.",
     [
-      GPU_ShuffleOpXor,
+      GPU_ShuffleOpXor, GPU_ShuffleOpUp, GPU_ShuffleOpDown, GPU_ShuffleOpIdx,
     ]>{
   let cppNamespace = "::mlir::gpu";
+  let storageType = "mlir::StringAttr";
+  let returnType = "::mlir::gpu::ShuffleModeAttr";
+  let convertFromStorage =
+    "*symbolizeEnum<::mlir::gpu::ShuffleModeAttr>($_self.getValue())";
+  let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
 }
 
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index db76cc1a93c32..b1182eae3237f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -97,22 +97,36 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
   let assemblyFormat = "attr-dict";
 }
 
-def NVVM_ShflBflyOp :
-  NVVM_Op<"shfl.sync.bfly">,
+def ShflKindBfly : StrEnumAttrCase<"bfly">;
+def ShflKindUp : StrEnumAttrCase<"up">;
+def ShflKindDown : StrEnumAttrCase<"down">;
+def ShflKindIdx : StrEnumAttrCase<"idx">;
+
+/// Enum attribute of the different shuffle kinds.
+def ShflKind : StrEnumAttr<"ShflKind", "NVVM shuffle kind",
+  [ShflKindBfly, ShflKindUp, ShflKindDown, ShflKindIdx]> {
+  let cppNamespace = "::mlir::NVVM";
+  let storageType = "mlir::StringAttr";
+  let returnType = "NVVM::ShflKind";
+  let convertFromStorage = "*symbolizeEnum<NVVM::ShflKind>($_self.getValue())";
+  let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
+}
+
+def NVVM_ShflOp :
+  NVVM_Op<"shfl.sync">,
   Results<(outs LLVM_Type:$res)>,
-  Arguments<(ins LLVM_Type:$dst,
+  Arguments<(ins I32:$dst,
                  LLVM_Type:$val,
-                 LLVM_Type:$offset,
-                 LLVM_Type:$mask_and_clamp,
+                 I32:$offset,
+                 I32:$mask_and_clamp,
+                 ShflKind:$kind,
                  OptionalAttr<UnitAttr>:$return_value_and_is_valid)> {
   string llvmBuilder = [{
-      auto intId = getShflBflyIntrinsicId(
-          $_resultType, static_cast<bool>($return_value_and_is_valid));
+      auto intId = getShflIntrinsicId(
+          $_resultType, $kind, static_cast<bool>($return_value_and_is_valid));
       $res = createIntrinsicCall(builder,
           intId, {$dst, $val, $offset, $mask_and_clamp});
   }];
-  let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }];
-  let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
   let verifier = [{
     if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
       return success();
@@ -125,6 +139,10 @@ def NVVM_ShflBflyOp :
                        "i1 as the second element");
     return success();
   }];
+  let assemblyFormat = [{
+    $kind $dst `,` $val `,` $offset `,` $mask_and_clamp  attr-dict
+     `:` type($val) `->` type($res)
+   }];
 }
 
 def NVVM_VoteBallotOp :

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 5cec21ec328fa..ac43d1a98791a 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -39,6 +39,21 @@ using namespace mlir;
 
 namespace {
 
+/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
+static NVVM::ShflKind convertShflKind(gpu::ShuffleModeAttr mode) {
+  switch (mode) {
+  case gpu::ShuffleModeAttr::XOR:
+    return NVVM::ShflKind::bfly;
+  case gpu::ShuffleModeAttr::UP:
+    return NVVM::ShflKind::up;
+  case gpu::ShuffleModeAttr::DOWN:
+    return NVVM::ShflKind::down;
+  case gpu::ShuffleModeAttr::IDX:
+    return NVVM::ShflKind::idx;
+  }
+  llvm_unreachable("unknown shuffle mode");
+}
+
 struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
 
@@ -81,9 +96,9 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
         rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.width(), one);
 
     auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
-    Value shfl = rewriter.create<NVVM::ShflBflyOp>(
+    Value shfl = rewriter.create<NVVM::ShflOp>(
         loc, resultTy, activeMask, adaptor.value(), adaptor.offset(),
-        maskAndClamp, returnValueAndIsValidAttr);
+        maskAndClamp, convertShflKind(op.mode()), returnValueAndIsValidAttr);
     Value shflValue = rewriter.create<LLVM::ExtractValueOp>(
         loc, valueTy, shfl, rewriter.getIndexArrayAttr(0));
     Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index aa276a42882b9..dc1949f012bc1 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -302,7 +302,7 @@ static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
 }
 
 static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
-  p << ' ' << op.getOperands() << ' ' << op.mode() << " : "
+  p << ' ' << op.getOperands() << ' ' << stringifyEnum(op.mode()) << " : "
     << op.value().getType();
 }
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 8d47a81b917cb..bbfca3137c017 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -43,33 +43,6 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
     p << " : " << op->getResultTypes();
 }
 
-// <operation> ::=
-//     `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
-//      ({return_value_and_is_valid})? : result_type
-static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
-                                           OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 8> ops;
-  Type resultType;
-  if (parser.parseOperandList(ops) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(resultType) ||
-      parser.addTypeToList(resultType, result.types))
-    return failure();
-
-  for (auto &attr : result.attributes) {
-    if (attr.getName() != "return_value_and_is_valid")
-      continue;
-    auto structType = resultType.dyn_cast<LLVM::LLVMStructType>();
-    if (structType && !structType.getBody().empty())
-      resultType = structType.getBody()[0];
-    break;
-  }
-
-  auto int32Ty = IntegerType::get(parser.getContext(), 32);
-  return parser.resolveOperands(ops, {int32Ty, resultType, int32Ty, int32Ty},
-                                parser.getNameLoc(), result.operands);
-}
-
 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
 static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
                                          OperationState &result) {

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index f64d389bf9b20..a45cd7a7d4309 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -23,15 +23,45 @@ using namespace mlir;
 using namespace mlir::LLVM;
 using mlir::LLVM::detail::createIntrinsicCall;
 
-static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
-                                                  bool withPredicate) {
+static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
+                                              NVVM::ShflKind kind,
+                                              bool withPredicate) {
+
   if (withPredicate) {
     resultType = cast<llvm::StructType>(resultType)->getElementType(0);
-    return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
-                                   : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
+    switch (kind) {
+    case NVVM::ShflKind::bfly:
+      return resultType->isFloatTy()
+                 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
+                 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
+    case NVVM::ShflKind::up:
+      return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
+                                     : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
+    case NVVM::ShflKind::down:
+      return resultType->isFloatTy()
+                 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
+                 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
+    case NVVM::ShflKind::idx:
+      return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
+                                     : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
+    }
+  } else {
+    switch (kind) {
+    case NVVM::ShflKind::bfly:
+      return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
+                                     : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
+    case NVVM::ShflKind::up:
+      return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
+                                     : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
+    case NVVM::ShflKind::down:
+      return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
+                                     : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
+    case NVVM::ShflKind::idx:
+      return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
+                                     : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
+    }
   }
-  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
-                                 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
+  llvm_unreachable("unknown shuffle kind");
 }
 
 namespace {

diff  --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 53a8f9a3e1848..f2f81739aabe9 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -78,7 +78,7 @@ gpu.module @test_module {
   gpu.func @gpu_all_reduce_op() {
     %arg0 = arith.constant 1.0 : f32
     // TODO: Check full IR expansion once lowering has settled.
-    // CHECK: nvvm.shfl.sync.bfly
+    // CHECK: nvvm.shfl.sync "bfly" {{.*}}
     // CHECK: nvvm.barrier0
     // CHECK: llvm.fadd
     %result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
@@ -94,7 +94,7 @@ gpu.module @test_module {
   gpu.func @gpu_all_reduce_region() {
     %arg0 = arith.constant 1 : i32
     // TODO: Check full IR expansion once lowering has settled.
-    // CHECK: nvvm.shfl.sync.bfly
+    // CHECK: nvvm.shfl.sync "bfly" {{.*}}
     // CHECK: nvvm.barrier0
     %result = "gpu.all_reduce"(%arg0) ({
     ^bb(%lhs : i32, %rhs : i32):
@@ -109,7 +109,7 @@ gpu.module @test_module {
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_shuffle()
-  builtin.func @gpu_shuffle() -> (f32) {
+  builtin.func @gpu_shuffle() -> (f32, f32, f32, f32) {
     // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
     %arg0 = arith.constant 1.0 : f32
     // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
@@ -120,12 +120,18 @@ gpu.module @test_module {
     // CHECK: %[[#SHL:]] = llvm.shl %[[#ONE]], %[[#WIDTH]] : i32
     // CHECK: %[[#MASK:]] = llvm.sub %[[#SHL]], %[[#ONE]] : i32
     // CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : i32
-    // CHECK: %[[#SHFL:]] = nvvm.shfl.sync.bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] : !llvm.struct<(f32, i1)>
+    // CHECK: %[[#SHFL:]] = nvvm.shfl.sync "bfly" %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
     // CHECK: llvm.extractvalue %[[#SHFL]][0 : index] : !llvm.struct<(f32, i1)>
     // CHECK: llvm.extractvalue %[[#SHFL]][1 : index] : !llvm.struct<(f32, i1)>
     %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (f32, i1)
-
-    std.return %shfl : f32
+    // CHECK: nvvm.shfl.sync "up" {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+    %shflu, %predu = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "up" } : (f32, i32, i32) -> (f32, i1)
+    // CHECK: nvvm.shfl.sync "down" {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+    %shfld, %predd = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "down" } : (f32, i32, i32) -> (f32, i1)
+    // CHECK: nvvm.shfl.sync "idx" {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+    %shfli, %predi = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "idx" } : (f32, i32, i32) -> (f32, i1)
+
+    std.return %shfl, %shflu, %shfld, %shfli : f32, f32,f32, f32
   }
 }
 

diff  --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index c24fd7bf8a818..1c5ab143dcc44 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -55,6 +55,12 @@ module attributes {gpu.container_module} {
       %offset = arith.constant 3 : i32
       // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} xor : f32
       %shfl, %pred = gpu.shuffle %arg0, %offset, %width xor : f32
+      // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} up : f32
+      %shfl1, %pred1 = gpu.shuffle %arg0, %offset, %width up : f32
+      // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} down : f32
+      %shfl2, %pred2 = gpu.shuffle %arg0, %offset, %width down : f32
+      // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} idx : f32
+      %shfl3, %pred3 = gpu.shuffle %arg0, %offset, %width idx : f32
 
       "gpu.barrier"() : () -> ()
 

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fd9b5765fa2f2..9b9df36210a79 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -495,21 +495,21 @@ func @null_non_llvm_type() {
 
 func @nvvm_invalid_shfl_pred_1(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
  // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
-  %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32
+  %0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> i32
 }
 
 // -----
 
 func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
  // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
-  %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(i32)>
+  %0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32)>
 }
 
 // -----
 
 func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
  // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
-  %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(i32, i32)>
+  %0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 638a5ab47dd0d..e2ca4d71bbe25 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -37,20 +37,26 @@ func @llvm.nvvm.barrier0() {
 func @nvvm_shfl(
     %arg0 : i32, %arg1 : i32, %arg2 : i32,
     %arg3 : i32, %arg4 : f32) -> i32 {
-  // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32
-  %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 : i32
-  // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32
-  %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 : f32
+  // CHECK: nvvm.shfl.sync "bfly" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32 -> i32
+  %0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 : i32 -> i32
+  // CHECK: nvvm.shfl.sync "bfly" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
+  %1 = nvvm.shfl.sync "bfly" %arg0, %arg4, %arg1, %arg2 : f32 -> f32
+  // CHECK: nvvm.shfl.sync "up" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
+  %2 = nvvm.shfl.sync "up" %arg0, %arg4, %arg1, %arg2 : f32 -> f32
+  // CHECK: nvvm.shfl.sync "down" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
+  %3 = nvvm.shfl.sync "down" %arg0, %arg4, %arg1, %arg2 : f32 -> f32
+  // CHECK: nvvm.shfl.sync "idx" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
+  %4 = nvvm.shfl.sync "idx" %arg0, %arg4, %arg1, %arg2 : f32 -> f32
   llvm.return %0 : i32
 }
 
 func @nvvm_shfl_pred(
     %arg0 : i32, %arg1 : i32, %arg2 : i32,
     %arg3 : i32, %arg4 : f32) -> !llvm.struct<(i32, i1)> {
-  // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.struct<(i32, i1)>
-  %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(i32, i1)>
-  // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.struct<(f32, i1)>
-  %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(f32, i1)>
+  // CHECK: nvvm.shfl.sync "bfly" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+  %0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+  // CHECK: nvvm.shfl.sync "bfly" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+  %1 = nvvm.shfl.sync "bfly" %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
   llvm.return %0 : !llvm.struct<(i32, i1)>
 }
 

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 08859b22c0d83..a9ec2259dfcff 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -42,9 +42,21 @@ llvm.func @nvvm_shfl(
     %0 : i32, %1 : i32, %2 : i32,
     %3 : i32, %4 : f32) -> i32 {
   // CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
-  %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 : i32
+  %6 = nvvm.shfl.sync "bfly" %0, %3, %1, %2 : i32 -> i32
   // CHECK: call float @llvm.nvvm.shfl.sync.bfly.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
-  %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 : f32
+  %7 = nvvm.shfl.sync "bfly" %0, %4, %1, %2 : f32 -> f32
+  // CHECK: call i32 @llvm.nvvm.shfl.sync.up.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %8 = nvvm.shfl.sync "up" %0, %3, %1, %2 : i32 -> i32
+  // CHECK: call float @llvm.nvvm.shfl.sync.up.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %9 = nvvm.shfl.sync "up" %0, %4, %1, %2 : f32 -> f32
+  // CHECK: call i32 @llvm.nvvm.shfl.sync.down.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %10 = nvvm.shfl.sync "down" %0, %3, %1, %2 : i32 -> i32
+  // CHECK: call float @llvm.nvvm.shfl.sync.down.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %11 = nvvm.shfl.sync "down" %0, %4, %1, %2 : f32 -> f32
+  // CHECK: call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %12 = nvvm.shfl.sync "idx" %0, %3, %1, %2 : i32 -> i32
+  // CHECK: call float @llvm.nvvm.shfl.sync.idx.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %13 = nvvm.shfl.sync "idx" %0, %4, %1, %2 : f32 -> f32
   llvm.return %6 : i32
 }
 
@@ -52,9 +64,21 @@ llvm.func @nvvm_shfl_pred(
     %0 : i32, %1 : i32, %2 : i32,
     %3 : i32, %4 : f32) -> !llvm.struct<(i32, i1)> {
   // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
-  %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 {return_value_and_is_valid} : !llvm.struct<(i32, i1)>
+  %6 = nvvm.shfl.sync "bfly" %0, %3, %1, %2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
   // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
-  %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 {return_value_and_is_valid} : !llvm.struct<(f32, i1)>
+  %7 = nvvm.shfl.sync "bfly" %0, %4, %1, %2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+  // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.up.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %8 = nvvm.shfl.sync "up" %0, %3, %1, %2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+  // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.up.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %9 = nvvm.shfl.sync "up" %0, %4, %1, %2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+  // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.down.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %10 = nvvm.shfl.sync "down" %0, %3, %1, %2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+  // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.down.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %11 = nvvm.shfl.sync "down" %0, %4, %1, %2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+  // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.idx.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %12 = nvvm.shfl.sync "idx" %0, %3, %1, %2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+  // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.idx.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %13 = nvvm.shfl.sync "idx" %0, %4, %1, %2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
   llvm.return %6 : !llvm.struct<(i32, i1)>
 }
 


        


More information about the Mlir-commits mailing list