[Mlir-commits] [mlir] 47555d7 - [mlir][gpu] Extend shuffle op modes and add nvvm lowering
Thomas Raoux
llvmlistbot at llvm.org
Fri Nov 19 11:14:50 PST 2021
Author: Thomas Raoux
Date: 2021-11-19T11:14:31-08:00
New Revision: 47555d73f6538cf2c092a7314e3c82c631ce4ccb
URL: https://github.com/llvm/llvm-project/commit/47555d73f6538cf2c092a7314e3c82c631ce4ccb
DIFF: https://github.com/llvm/llvm-project/commit/47555d73f6538cf2c092a7314e3c82c631ce4ccb.diff
LOG: [mlir][gpu] Extend shuffle op modes and add nvvm lowering
Add up, down and idx modes to gpu shuffle ops, also change the mode from
string to enum
Differential Revision: https://reviews.llvm.org/D114188
Added:
Modified:
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
mlir/test/Dialect/GPU/ops.mlir
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/nvvm.mlir
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index adefba30d9a8e..71c13f878b658 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -647,13 +647,21 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
}
def GPU_ShuffleOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
+def GPU_ShuffleOpDown : StrEnumAttrCase<"DOWN", -1, "down">;
+def GPU_ShuffleOpUp : StrEnumAttrCase<"UP", -1, "up">;
+def GPU_ShuffleOpIdx : StrEnumAttrCase<"IDX", -1, "idx">;
def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr",
"Indexing modes supported by gpu.shuffle.",
[
- GPU_ShuffleOpXor,
+ GPU_ShuffleOpXor, GPU_ShuffleOpUp, GPU_ShuffleOpDown, GPU_ShuffleOpIdx,
]>{
let cppNamespace = "::mlir::gpu";
+ let storageType = "mlir::StringAttr";
+ let returnType = "::mlir::gpu::ShuffleModeAttr";
+ let convertFromStorage =
+ "*symbolizeEnum<::mlir::gpu::ShuffleModeAttr>($_self.getValue())";
+ let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index db76cc1a93c32..b1182eae3237f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -97,22 +97,36 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
let assemblyFormat = "attr-dict";
}
-def NVVM_ShflBflyOp :
- NVVM_Op<"shfl.sync.bfly">,
+def ShflKindBfly : StrEnumAttrCase<"bfly">;
+def ShflKindUp : StrEnumAttrCase<"up">;
+def ShflKindDown : StrEnumAttrCase<"down">;
+def ShflKindIdx : StrEnumAttrCase<"idx">;
+
+/// Enum attribute of the different shuffle kinds.
+def ShflKind : StrEnumAttr<"ShflKind", "NVVM shuffle kind",
+ [ShflKindBfly, ShflKindUp, ShflKindDown, ShflKindIdx]> {
+ let cppNamespace = "::mlir::NVVM";
+ let storageType = "mlir::StringAttr";
+ let returnType = "NVVM::ShflKind";
+ let convertFromStorage = "*symbolizeEnum<NVVM::ShflKind>($_self.getValue())";
+ let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
+}
+
+def NVVM_ShflOp :
+ NVVM_Op<"shfl.sync">,
Results<(outs LLVM_Type:$res)>,
- Arguments<(ins LLVM_Type:$dst,
+ Arguments<(ins I32:$dst,
LLVM_Type:$val,
- LLVM_Type:$offset,
- LLVM_Type:$mask_and_clamp,
+ I32:$offset,
+ I32:$mask_and_clamp,
+ ShflKind:$kind,
OptionalAttr<UnitAttr>:$return_value_and_is_valid)> {
string llvmBuilder = [{
- auto intId = getShflBflyIntrinsicId(
- $_resultType, static_cast<bool>($return_value_and_is_valid));
+ auto intId = getShflIntrinsicId(
+ $_resultType, $kind, static_cast<bool>($return_value_and_is_valid));
$res = createIntrinsicCall(builder,
intId, {$dst, $val, $offset, $mask_and_clamp});
}];
- let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }];
- let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
let verifier = [{
if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
return success();
@@ -125,6 +139,10 @@ def NVVM_ShflBflyOp :
"i1 as the second element");
return success();
}];
+ let assemblyFormat = [{
+ $kind $dst `,` $val `,` $offset `,` $mask_and_clamp attr-dict
+ `:` type($val) `->` type($res)
+ }];
}
def NVVM_VoteBallotOp :
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 5cec21ec328fa..ac43d1a98791a 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -39,6 +39,21 @@ using namespace mlir;
namespace {
+/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
+static NVVM::ShflKind convertShflKind(gpu::ShuffleModeAttr mode) {
+ switch (mode) {
+ case gpu::ShuffleModeAttr::XOR:
+ return NVVM::ShflKind::bfly;
+ case gpu::ShuffleModeAttr::UP:
+ return NVVM::ShflKind::up;
+ case gpu::ShuffleModeAttr::DOWN:
+ return NVVM::ShflKind::down;
+ case gpu::ShuffleModeAttr::IDX:
+ return NVVM::ShflKind::idx;
+ }
+ llvm_unreachable("unknown shuffle mode");
+}
+
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
@@ -81,9 +96,9 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.width(), one);
auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
- Value shfl = rewriter.create<NVVM::ShflBflyOp>(
+ Value shfl = rewriter.create<NVVM::ShflOp>(
loc, resultTy, activeMask, adaptor.value(), adaptor.offset(),
- maskAndClamp, returnValueAndIsValidAttr);
+ maskAndClamp, convertShflKind(op.mode()), returnValueAndIsValidAttr);
Value shflValue = rewriter.create<LLVM::ExtractValueOp>(
loc, valueTy, shfl, rewriter.getIndexArrayAttr(0));
Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index aa276a42882b9..dc1949f012bc1 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -302,7 +302,7 @@ static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
}
static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
- p << ' ' << op.getOperands() << ' ' << op.mode() << " : "
+ p << ' ' << op.getOperands() << ' ' << stringifyEnum(op.mode()) << " : "
<< op.value().getType();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 8d47a81b917cb..bbfca3137c017 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -43,33 +43,6 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
p << " : " << op->getResultTypes();
}
-// <operation> ::=
-// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
-// ({return_value_and_is_valid})? : result_type
-static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 8> ops;
- Type resultType;
- if (parser.parseOperandList(ops) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(resultType) ||
- parser.addTypeToList(resultType, result.types))
- return failure();
-
- for (auto &attr : result.attributes) {
- if (attr.getName() != "return_value_and_is_valid")
- continue;
- auto structType = resultType.dyn_cast<LLVM::LLVMStructType>();
- if (structType && !structType.getBody().empty())
- resultType = structType.getBody()[0];
- break;
- }
-
- auto int32Ty = IntegerType::get(parser.getContext(), 32);
- return parser.resolveOperands(ops, {int32Ty, resultType, int32Ty, int32Ty},
- parser.getNameLoc(), result.operands);
-}
-
// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
OperationState &result) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index f64d389bf9b20..a45cd7a7d4309 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -23,15 +23,45 @@ using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::createIntrinsicCall;
-static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
- bool withPredicate) {
+static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
+ NVVM::ShflKind kind,
+ bool withPredicate) {
+
if (withPredicate) {
resultType = cast<llvm::StructType>(resultType)->getElementType(0);
- return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
- : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
+ switch (kind) {
+ case NVVM::ShflKind::bfly:
+ return resultType->isFloatTy()
+ ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
+ : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
+ case NVVM::ShflKind::up:
+ return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
+ : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
+ case NVVM::ShflKind::down:
+ return resultType->isFloatTy()
+ ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
+ : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
+ case NVVM::ShflKind::idx:
+ return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
+ : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
+ }
+ } else {
+ switch (kind) {
+ case NVVM::ShflKind::bfly:
+ return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
+ : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
+ case NVVM::ShflKind::up:
+ return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
+ : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
+ case NVVM::ShflKind::down:
+ return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
+ : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
+ case NVVM::ShflKind::idx:
+ return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
+ : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
+ }
}
- return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
- : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
+ llvm_unreachable("unknown shuffle kind");
}
namespace {
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 53a8f9a3e1848..f2f81739aabe9 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -78,7 +78,7 @@ gpu.module @test_module {
gpu.func @gpu_all_reduce_op() {
%arg0 = arith.constant 1.0 : f32
// TODO: Check full IR expansion once lowering has settled.
- // CHECK: nvvm.shfl.sync.bfly
+ // CHECK: nvvm.shfl.sync "bfly" {{.*}}
// CHECK: nvvm.barrier0
// CHECK: llvm.fadd
%result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
@@ -94,7 +94,7 @@ gpu.module @test_module {
gpu.func @gpu_all_reduce_region() {
%arg0 = arith.constant 1 : i32
// TODO: Check full IR expansion once lowering has settled.
- // CHECK: nvvm.shfl.sync.bfly
+ // CHECK: nvvm.shfl.sync "bfly" {{.*}}
// CHECK: nvvm.barrier0
%result = "gpu.all_reduce"(%arg0) ({
^bb(%lhs : i32, %rhs : i32):
@@ -109,7 +109,7 @@ gpu.module @test_module {
gpu.module @test_module {
// CHECK-LABEL: func @gpu_shuffle()
- builtin.func @gpu_shuffle() -> (f32) {
+ builtin.func @gpu_shuffle() -> (f32, f32, f32, f32) {
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
%arg0 = arith.constant 1.0 : f32
// CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
@@ -120,12 +120,18 @@ gpu.module @test_module {
// CHECK: %[[#SHL:]] = llvm.shl %[[#ONE]], %[[#WIDTH]] : i32
// CHECK: %[[#MASK:]] = llvm.sub %[[#SHL]], %[[#ONE]] : i32
// CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : i32
- // CHECK: %[[#SHFL:]] = nvvm.shfl.sync.bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] : !llvm.struct<(f32, i1)>
+ // CHECK: %[[#SHFL:]] = nvvm.shfl.sync "bfly" %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
// CHECK: llvm.extractvalue %[[#SHFL]][0 : index] : !llvm.struct<(f32, i1)>
// CHECK: llvm.extractvalue %[[#SHFL]][1 : index] : !llvm.struct<(f32, i1)>
%shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (f32, i1)
-
- std.return %shfl : f32
+ // CHECK: nvvm.shfl.sync "up" {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+ %shflu, %predu = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "up" } : (f32, i32, i32) -> (f32, i1)
+ // CHECK: nvvm.shfl.sync "down" {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+ %shfld, %predd = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "down" } : (f32, i32, i32) -> (f32, i1)
+ // CHECK: nvvm.shfl.sync "idx" {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+ %shfli, %predi = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "idx" } : (f32, i32, i32) -> (f32, i1)
+
+ std.return %shfl, %shflu, %shfld, %shfli : f32, f32,f32, f32
}
}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index c24fd7bf8a818..1c5ab143dcc44 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -55,6 +55,12 @@ module attributes {gpu.container_module} {
%offset = arith.constant 3 : i32
// CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} xor : f32
%shfl, %pred = gpu.shuffle %arg0, %offset, %width xor : f32
+ // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} up : f32
+ %shfl1, %pred1 = gpu.shuffle %arg0, %offset, %width up : f32
+ // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} down : f32
+ %shfl2, %pred2 = gpu.shuffle %arg0, %offset, %width down : f32
+ // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} idx : f32
+ %shfl3, %pred3 = gpu.shuffle %arg0, %offset, %width idx : f32
"gpu.barrier"() : () -> ()
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fd9b5765fa2f2..9b9df36210a79 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -495,21 +495,21 @@ func @null_non_llvm_type() {
func @nvvm_invalid_shfl_pred_1(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
// expected-error @+1 {{expected return type to be a two-element struct with i1 as the second element}}
- %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32
+ %0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> i32
}
// -----
func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
// expected-error @+1 {{expected return type to be a two-element struct with i1 as the second element}}
- %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(i32)>
+ %0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32)>
}
// -----
func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
// expected-error @+1 {{expected return type to be a two-element struct with i1 as the second element}}
- %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(i32, i32)>
+ %0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)>
}
// -----
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 638a5ab47dd0d..e2ca4d71bbe25 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -37,20 +37,26 @@ func @llvm.nvvm.barrier0() {
func @nvvm_shfl(
%arg0 : i32, %arg1 : i32, %arg2 : i32,
%arg3 : i32, %arg4 : f32) -> i32 {
- // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32
- %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 : i32
- // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32
- %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 : f32
+ // CHECK: nvvm.shfl.sync "bfly" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32 -> i32
+ %0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 : i32 -> i32
+ // CHECK: nvvm.shfl.sync "bfly" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
+ %1 = nvvm.shfl.sync "bfly" %arg0, %arg4, %arg1, %arg2 : f32 -> f32
+ // CHECK: nvvm.shfl.sync "up" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
+ %2 = nvvm.shfl.sync "up" %arg0, %arg4, %arg1, %arg2 : f32 -> f32
+ // CHECK: nvvm.shfl.sync "down" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
+ %3 = nvvm.shfl.sync "down" %arg0, %arg4, %arg1, %arg2 : f32 -> f32
+ // CHECK: nvvm.shfl.sync "idx" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
+ %4 = nvvm.shfl.sync "idx" %arg0, %arg4, %arg1, %arg2 : f32 -> f32
llvm.return %0 : i32
}
func @nvvm_shfl_pred(
%arg0 : i32, %arg1 : i32, %arg2 : i32,
%arg3 : i32, %arg4 : f32) -> !llvm.struct<(i32, i1)> {
- // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.struct<(i32, i1)>
- %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(i32, i1)>
- // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.struct<(f32, i1)>
- %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(f32, i1)>
+ // CHECK: nvvm.shfl.sync "bfly" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+ %0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+ // CHECK: nvvm.shfl.sync "bfly" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+ %1 = nvvm.shfl.sync "bfly" %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
llvm.return %0 : !llvm.struct<(i32, i1)>
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 08859b22c0d83..a9ec2259dfcff 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -42,9 +42,21 @@ llvm.func @nvvm_shfl(
%0 : i32, %1 : i32, %2 : i32,
%3 : i32, %4 : f32) -> i32 {
// CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
- %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 : i32
+ %6 = nvvm.shfl.sync "bfly" %0, %3, %1, %2 : i32 -> i32
// CHECK: call float @llvm.nvvm.shfl.sync.bfly.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
- %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 : f32
+ %7 = nvvm.shfl.sync "bfly" %0, %4, %1, %2 : f32 -> f32
+ // CHECK: call i32 @llvm.nvvm.shfl.sync.up.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %8 = nvvm.shfl.sync "up" %0, %3, %1, %2 : i32 -> i32
+ // CHECK: call float @llvm.nvvm.shfl.sync.up.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %9 = nvvm.shfl.sync "up" %0, %4, %1, %2 : f32 -> f32
+ // CHECK: call i32 @llvm.nvvm.shfl.sync.down.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %10 = nvvm.shfl.sync "down" %0, %3, %1, %2 : i32 -> i32
+ // CHECK: call float @llvm.nvvm.shfl.sync.down.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %11 = nvvm.shfl.sync "down" %0, %4, %1, %2 : f32 -> f32
+ // CHECK: call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %12 = nvvm.shfl.sync "idx" %0, %3, %1, %2 : i32 -> i32
+ // CHECK: call float @llvm.nvvm.shfl.sync.idx.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %13 = nvvm.shfl.sync "idx" %0, %4, %1, %2 : f32 -> f32
llvm.return %6 : i32
}
@@ -52,9 +64,21 @@ llvm.func @nvvm_shfl_pred(
%0 : i32, %1 : i32, %2 : i32,
%3 : i32, %4 : f32) -> !llvm.struct<(i32, i1)> {
// CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
- %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 {return_value_and_is_valid} : !llvm.struct<(i32, i1)>
+ %6 = nvvm.shfl.sync "bfly" %0, %3, %1, %2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
// CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
- %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 {return_value_and_is_valid} : !llvm.struct<(f32, i1)>
+ %7 = nvvm.shfl.sync "bfly" %0, %4, %1, %2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+ // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.up.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %8 = nvvm.shfl.sync "up" %0, %3, %1, %2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+ // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.up.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %9 = nvvm.shfl.sync "up" %0, %4, %1, %2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+ // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.down.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %10 = nvvm.shfl.sync "down" %0, %3, %1, %2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+ // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.down.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %11 = nvvm.shfl.sync "down" %0, %4, %1, %2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+ // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.idx.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %12 = nvvm.shfl.sync "idx" %0, %3, %1, %2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+ // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.idx.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %13 = nvvm.shfl.sync "idx" %0, %4, %1, %2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
llvm.return %6 : !llvm.struct<(i32, i1)>
}
More information about the Mlir-commits
mailing list