[Mlir-commits] [mlir] Lower shuffle to single-result form if possible. (PR #84321)

Johannes Reifferscheid llvmlistbot at llvm.org
Thu Mar 21 00:00:35 PDT 2024


https://github.com/jreiffers updated https://github.com/llvm/llvm-project/pull/84321

From 8e9968372f00bd2996bef918c912d393d854d511 Mon Sep 17 00:00:00 2001
From: Johannes Reifferscheid <jreiffers at google.com>
Date: Thu, 21 Mar 2024 07:54:00 +0100
Subject: [PATCH] Lower shuffle to single-result form if possible.

We currently always lower gpu.shuffle to the struct-returning variant of
nvvm.shfl.sync. I saw cases where the unused predicate result survived all
the way to PTX, resulting in increased register usage. The easiest fix is to
lower to the single-result form whenever the predicate is unused.
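For illustration, roughly what this changes in the lowered IR when the
predicate result of gpu.shuffle has no uses (the SSA names and the
mask/clamp operands below are made up; the exact output is checked by the
updated tests):

  Before: always the struct-returning form plus extractvalues.
    %0 = nvvm.shfl.sync bfly %mask, %val, %offset, %clamp
           {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
    %1 = llvm.extractvalue %0[0] : !llvm.struct<(f32, i1)>
    %2 = llvm.extractvalue %0[1] : !llvm.struct<(f32, i1)>  // dead

  After: single-result form, no struct and no extractvalues.
    %0 = nvvm.shfl.sync bfly %mask, %val, %offset, %clamp : f32 -> f32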
---
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        | 20 +++++++----
 .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir     | 36 +++++++++++++++++--
 2 files changed, 47 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index d6a5d8cd74d5f2..3814c5a6fde87a 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -155,8 +155,6 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
     auto valueTy = adaptor.getValue().getType();
     auto int32Type = IntegerType::get(rewriter.getContext(), 32);
     auto predTy = IntegerType::get(rewriter.getContext(), 1);
-    auto resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
-                                                     {valueTy, predTy});
 
     Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
     Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
@@ -176,13 +174,23 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
           rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
     }
 
-    auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
+    bool predIsUsed = !op->getResult(1).use_empty();
+    UnitAttr returnValueAndIsValidAttr = nullptr;
+    Type resultTy = valueTy;
+    if (predIsUsed) {
+      returnValueAndIsValidAttr = rewriter.getUnitAttr();
+      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
+                                                  {valueTy, predTy});
+    }
     Value shfl = rewriter.create<NVVM::ShflOp>(
         loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
         maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
-    Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
-    Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
-
+    Value shflValue = shfl;
+    Value isActiveSrcLane = nullptr;
+    if (predIsUsed) {
+      shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
+      isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
+    }
     rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
     return success();
   }
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index dd3b6c2080aa21..8877ee083286b4 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -112,7 +112,7 @@ gpu.module @test_module_3 {
 
 gpu.module @test_module_4 {
   // CHECK-LABEL: func @gpu_shuffle()
-  func.func @gpu_shuffle() -> (f32, f32, f32, f32) {
+  func.func @gpu_shuffle() -> (f32, f32, f32, f32, i1, i1, i1, i1) {
     // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
     %arg0 = arith.constant 1.0 : f32
     // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
@@ -143,11 +143,41 @@ gpu.module @test_module_4 {
     // CHECK: nvvm.shfl.sync idx {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
     %shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
 
-    func.return %shfl, %shflu, %shfld, %shfli : f32, f32,f32, f32
+    func.return %shfl, %shflu, %shfld, %shfli, %pred, %predu, %predd, %predi
+      : f32, f32, f32, f32, i1, i1, i1, i1
   }
-}
 
+  // CHECK-LABEL: func @gpu_shuffle_unused_pred()
+  func.func @gpu_shuffle_unused_pred() -> (f32, f32, f32, f32) {
+    // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+    %arg0 = arith.constant 1.0 : f32
+    // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
+    %arg1 = arith.constant 4 : i32
+    // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : i32
+    %arg2 = arith.constant 23 : i32
+    // CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: %[[#MINUS_ONE:]] = llvm.mlir.constant(-1 : i32) : i32
+    // CHECK: %[[#THIRTY_TWO:]] = llvm.mlir.constant(32 : i32) : i32
+    // CHECK: %[[#NUM_LANES:]] = llvm.sub %[[#THIRTY_TWO]], %[[#WIDTH]] : i32
+    // CHECK: %[[#MASK:]] = llvm.lshr %[[#MINUS_ONE]], %[[#NUM_LANES]] : i32
+    // CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : i32
+    // CHECK: %[[#SHFL:]] = nvvm.shfl.sync bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] : f32 -> f32
+    %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : f32
+    // CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: %[[#MINUS_ONE:]] = llvm.mlir.constant(-1 : i32) : i32
+    // CHECK: %[[#THIRTY_TWO:]] = llvm.mlir.constant(32 : i32) : i32
+    // CHECK: %[[#NUM_LANES:]] = llvm.sub %[[#THIRTY_TWO]], %[[#WIDTH]] : i32
+    // CHECK: %[[#MASK:]] = llvm.lshr %[[#MINUS_ONE]], %[[#NUM_LANES]] : i32
+    // CHECK: %[[#SHFL:]] = nvvm.shfl.sync up %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#NUM_LANES]] : f32 -> f32
+    %shflu, %predu = gpu.shuffle up %arg0, %arg1, %arg2 : f32
+    // CHECK: nvvm.shfl.sync down {{.*}} : f32 -> f32
+    %shfld, %predd = gpu.shuffle down %arg0, %arg1, %arg2 : f32
+    // CHECK: nvvm.shfl.sync idx {{.*}} : f32 -> f32
+    %shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
 
+    func.return %shfl, %shflu, %shfld, %shfli : f32, f32, f32, f32
+  }
+}
 
 gpu.module @test_module_5 {
   // CHECK-LABEL: func @gpu_sync()


