[Mlir-commits] [mlir] 1896096 - [mlir][ROCM] Add Wave/Warp shuffle lowering and op for ROCM.

Anush Elangovan llvmlistbot at llvm.org
Thu Aug 24 17:36:55 PDT 2023


Author: Stanley Winata
Date: 2023-08-24T17:35:34-07:00
New Revision: 1896096002b75b50d46ee0043c20e90c7e27604a

URL: https://github.com/llvm/llvm-project/commit/1896096002b75b50d46ee0043c20e90c7e27604a
DIFF: https://github.com/llvm/llvm-project/commit/1896096002b75b50d46ee0043c20e90c7e27604a.diff

LOG: [mlir][ROCM] Add Wave/Warp shuffle lowering and op for ROCM.

Reductions are heavily used in many DL workloads, especially in
softmax/attention layers. Wave/warp shuffles and reductions are known to be
a fast, efficient way to perform these reductions.

In this patch we introduce AMD shuffle intrinsic ops to ROCDL, along with their corresponding lowering from gpu.shuffle. This should speed up many DL workloads on the ROCM backend. Currently, we support the xor and idx modes, which are the more common ones. In the future, we plan to add support for down and up, as well as to use ds_swizzle to further improve performance when the width and offsets are constant.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D158684

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
    mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
    mlir/test/Target/LLVMIR/rocdl.mlir

Removed: 
    mlir/test/Conversion/GPUToROCDL/invalid.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 0738f09ac4164a..13515cd669570e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -108,6 +108,35 @@ class ROCDL_MbcntOp<string mnemonic> :
 def ROCDL_MbcntLoOp : ROCDL_MbcntOp<"lo">;
 def ROCDL_MbcntHiOp : ROCDL_MbcntOp<"hi">;
 
+// rocdl.ds_swizzle: thin wrapper that emits a call to the
+// llvm.amdgcn.ds.swizzle intrinsic with operands ($src, $offset).
+// All operands and the result are i32.
+// NOTE(review): the underlying intrinsic is believed to require `offset` to be
+// an immediate; a non-constant SSA value may be rejected further down the
+// LLVM pipeline -- confirm, and consider modelling it as an attribute.
+def ROCDL_DsSwizzleOp : ROCDL_Op<"ds_swizzle"> {
+  let arguments = (ins I32:$src, I32:$offset);
+  let results = (outs I32:$res);
+  let llvmBuilder = [{
+    $res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_ds_swizzle, {$src, $offset});
+  }];
+  let assemblyFormat = [{
+    $src `,` $offset  attr-dict `:` `(` type($src) `,` type($offset) `)` `->` type($res)
+   }];
+}
+
+// rocdl.ds_bpermute: thin wrapper that emits a call to the
+// llvm.amdgcn.ds.bpermute intrinsic with operands ($index, $src).
+// All operands and the result are i32. The gpu.shuffle lowering passes a
+// lane index scaled by 4 (dstLane << 2) as `index`.
+def ROCDL_DsBpermuteOp : ROCDL_Op<"ds_bpermute"> {
+  let arguments = (ins I32:$index, I32:$src);
+  let results = (outs I32:$res);
+  let llvmBuilder = [{
+    $res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_ds_bpermute, {$index, $src});
+  }];
+  let assemblyFormat = [{
+    $index `,` $src  attr-dict `:` `(` type($index) `,` type($src) `)` `->` type($res)
+   }];
+}
+
+
 //===----------------------------------------------------------------------===//
 // Thread index and Block index
 

diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 386ac3c9db3d11..7a90f228f3cf73 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -64,6 +64,18 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
   return canBeBare;
 }
 
+/// Returns the lane id of the current thread as an i32 value, computed as
+/// mbcnt.hi(-1, mbcnt.lo(-1, 0)).
+///
+/// `indexBitwidth` is currently unused: the result is always i32 regardless of
+/// the index bitwidth the caller's type converter uses.
+static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
+                       const unsigned indexBitwidth) {
+  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+  // Materialize the constants directly as LLVM dialect ops (as the shuffle
+  // lowering in this file does) so this helper does not rely on arith-to-LLVM
+  // patterns being part of the same conversion.
+  Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
+  Value minus1 = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
+  Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
+                                                    ValueRange{minus1, zero});
+  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
+                                                   ValueRange{minus1, mbcntLo});
+  return laneId;
+}
+
 namespace {
 struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
@@ -98,6 +110,80 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   }
 };
 
+struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
+  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
+
+  /// Lowers a shuffle to the corresponding ROCDL ops.
+  ///
+  /// Use the `width` argument to see if the source lane is participating.
+  /// If it is not, the destination lane falls back to the current lane.
+  ///
+  ///  Shuffle with DS Bpermute:
+  ///   let shflMode = [xor, idx] (up and down are not supported yet)
+  ///   let width = 32 (usually warpsize), step = [1, 2, 4, 8, 16, ... , width].
+  ///   1. curLaneId = using mbcnt.lo + mbcnt.hi
+  ///   2. widthOrZeroIfOutside = (curLaneId + width) & -width
+  ///   3. dstLane = shflMode(curLaneId, step)
+  ///   4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
+  ///   5. dstLane = isActiveSrcLane ? dstLane : curLaneId
+  ///   6. dwordAlignedDstLane = dstLane * 4 or dstLane << 2.
+  ///   7. bpermute(dwordAlignedDstLane, shfl_value).
+  ///
+  /// When `width` is a power of two, step 2 yields the exclusive upper bound
+  /// of the width-sized lane segment that contains curLaneId.
+  ///
+  /// Only 32-bit (i32/f32) values and the XOR/IDX modes are handled. Anything
+  /// else is rejected with a match-failure reason before any IR is created.
+  LogicalResult
+  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value initShflValue = adaptor.getValue();
+    Type shflType = initShflValue.getType();
+
+    // Validate everything up front so that a failed match leaves the IR
+    // untouched.
+    // TODO: Add support for non 32-bit shuffle values.
+    if (!shflType.isIntOrFloat() || shflType.getIntOrFloatBitWidth() != 32)
+      return rewriter.notifyMatchFailure(
+          op, "only 32-bit shuffle values are supported");
+    // TODO: Add support for gpu::ShuffleMode::UP and gpu::ShuffleMode::DOWN.
+    // TODO: Use ds_swizzle for XOR when step/offsets are constants for better
+    // perf.
+    gpu::ShuffleMode mode = op.getMode();
+    if (mode != gpu::ShuffleMode::XOR && mode != gpu::ShuffleMode::IDX)
+      return rewriter.notifyMatchFailure(op, "unsupported shuffle mode");
+
+    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
+    Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
+
+    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+    Value width = adaptor.getWidth();
+    Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
+    Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
+    Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
+    Value widthOrZeroIfOutside =
+        rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
+
+    // IDX reads directly from lane `offset`; XOR reads from lane
+    // `srcLaneId ^ offset`.
+    Value dstLane = adaptor.getOffset();
+    if (mode == gpu::ShuffleMode::XOR)
+      dstLane =
+          rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId, dstLane);
+
+    Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
+        loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
+    Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
+                                                          dstLane, srcLaneId);
+    // The lane index is scaled by 4 (one dword per lane) before it is handed
+    // to ds_bpermute.
+    Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
+    Value dwordAlignedDstLane =
+        rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
+
+    // rocdl.ds_bpermute only takes and returns i32, so an f32 payload is
+    // bitcast to i32 before the permute and back to f32 afterwards.
+    bool isFloat = shflType.isF32();
+    if (isFloat)
+      initShflValue =
+          rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
+    Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
+        loc, int32Type, dwordAlignedDstLane, initShflValue);
+    if (isFloat)
+      shflValue = rewriter.create<LLVM::BitcastOp>(loc, shflType, shflValue);
+
+    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
+    return success();
+  }
+};
+
 /// Import the GPU Ops to ROCDL Patterns.
 #include "GPUToROCDL.cpp.inc"
 
@@ -278,7 +364,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
     patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
   }
 
-  patterns.add<GPULaneIdOpToROCDL>(converter);
+  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
 
   populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
                                    "__ocml_fabs_f64");

diff  --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index a3a73e281db710..8f27c4b0bcc74b 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -524,3 +524,92 @@ gpu.module @module {
   }
 }
 
+// -----
+
+// gpu.all_reduce (f32 add) is expanded into gpu.shuffle ops, which are now
+// lowered to rocdl.ds_bpermute by the ROCDL shuffle pattern instead of failing
+// to legalize. Only the rough shape of the expansion is matched here.
+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_all_reduce_op()
+  gpu.func @gpu_all_reduce_op() {
+    %arg0 = arith.constant 1.0 : f32
+    // TODO: Check full IR expansion once lowering has settled.
+    // CHECK: llvm.add
+    // CHECK: llvm.and
+    // CHECK: llvm.xor
+    // CHECK: llvm.icmp "slt"
+    // CHECK: llvm.select
+    // CHECK: llvm.shl
+    // CHECK: rocdl.ds_bpermute {{.*}}
+    // CHECK: rocdl.barrier
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.fadd
+    %result = gpu.all_reduce add %arg0 uniform {} : (f32) -> (f32)
+
+    gpu.return
+  }
+}
+
+
+// -----
+
+// gpu.all_reduce with a custom i32 reduction region (xori). The expansion goes
+// through the same shuffle -> rocdl.ds_bpermute path; the value is already
+// i32, so no bitcast is involved around the permute.
+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_all_reduce_region()
+  gpu.func @gpu_all_reduce_region() {
+    %arg0 = arith.constant 1 : i32
+    // TODO: Check full IR expansion once lowering has settled.
+    // CHECK: llvm.add
+    // CHECK: llvm.and
+    // CHECK: llvm.xor
+    // CHECK: llvm.icmp "slt"
+    // CHECK: llvm.select
+    // CHECK: llvm.shl
+    // CHECK: rocdl.ds_bpermute {{.*}}
+    // CHECK: rocdl.barrier
+    %result = gpu.all_reduce %arg0 uniform {
+    ^bb(%lhs : i32, %rhs : i32):
+      %xor = arith.xori %lhs, %rhs : i32
+      "gpu.yield"(%xor) : (i32) -> ()
+    } : (i32) -> (i32)
+    gpu.return
+  }
+}
+
+// -----
+
+// Exact lowering of gpu.shuffle in xor and idx mode on an f32 value.
+// Wildcard operands: the second xor operand is the shuffle offset, and the
+// third select operand is the current lane id (fallback when the source lane
+// is not active). In idx mode the offset itself is the destination lane.
+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_shuffle()
+  func.func @gpu_shuffle() -> (f32, f32) {
+    // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+    %arg0 = arith.constant 1.0 : f32
+    // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
+    %arg1 = arith.constant 4 : i32
+    // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : i32
+    %arg2 = arith.constant 23 : i32
+    // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
+    // CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[#WIDTH]] : i32
+    // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+    // CHECK: %[[#WARP_OR_ZERO:]] = llvm.and %[[#ADD]], %[[#NEG_WIDTH]] : i32
+    // CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %{{.*}} : i32
+    // CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[#XOR]], %[[#WARP_OR_ZERO]] : i32
+    // CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[#XOR]], %{{.*}} : i1, i32
+    // CHECK: %[[#TWO:]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK: %[[#ALIGNED_DST_LANE:]] = llvm.shl %[[#DST_LANE]], %[[#TWO]] : i32
+    // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+    // CHECK: %[[#PERMUTE:]] = rocdl.ds_bpermute %[[#ALIGNED_DST_LANE]], %[[#CAST_VALUE]] : (i32, i32) -> i32
+    // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
+    %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : f32
+    // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
+    // CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[#WIDTH]] : i32
+    // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+    // CHECK: %[[#WARP_OR_ZERO:]] = llvm.and %[[#ADD]], %[[#NEG_WIDTH]] : i32
+    // CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[#OFFSET]], %[[#WARP_OR_ZERO]] : i32
+    // CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[#OFFSET]], %{{.*}} : i1, i32
+    // CHECK: %[[#TWO:]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK: %[[#ALIGNED_DST_LANE:]] = llvm.shl %[[#DST_LANE]], %[[#TWO]] : i32
+    // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+    // CHECK: %[[#PERMUTE:]] = rocdl.ds_bpermute %[[#ALIGNED_DST_LANE]], %[[#CAST_VALUE]] : (i32, i32) -> i32
+    // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
+    %shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
+    func.return %shfl, %shfli : f32, f32
+  }
+}
\ No newline at end of file

diff  --git a/mlir/test/Conversion/GPUToROCDL/invalid.mlir b/mlir/test/Conversion/GPUToROCDL/invalid.mlir
deleted file mode 100644
index a4b59c5d218e1f..00000000000000
--- a/mlir/test/Conversion/GPUToROCDL/invalid.mlir
+++ /dev/null
@@ -1,17 +0,0 @@
-// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file --verify-diagnostics
-
-
-// Demonstrate the need to register the cf and memref dialect as dependent.
-// CHECK-LABEL: @dependentDialect
-gpu.module @module {
-  gpu.func @dependentDialect() {
-    %arg0 = arith.constant 1 : i32
-    // expected-error at +1 {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illega}}
-    %result = gpu.all_reduce %arg0 uniform {
-    ^bb(%lhs : i32, %rhs : i32):
-      %xor = arith.xori %lhs, %rhs : i32
-      "gpu.yield"(%xor) : (i32) -> ()
-    } : (i32) -> (i32)
-    gpu.return
-  }
-}

diff  --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index dfee70b99d3910..83467553ad5f4c 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -66,6 +66,22 @@ llvm.func @rocdl.lane_id() -> i32 {
   llvm.return %3 : i32
 }
 
+// rocdl.ds_swizzle with a constant offset translates to a call of the
+// llvm.amdgcn.ds.swizzle intrinsic.
+llvm.func @rocdl.swizzle(%src : i32) -> i32 {
+  // CHECK-LABEL: rocdl.swizzle
+  // CHECK: call i32 @llvm.amdgcn.ds.swizzle
+  %offset = llvm.mlir.constant(100 : i32) : i32
+  %0 = rocdl.ds_swizzle %src, %offset : (i32, i32) -> i32
+  llvm.return %0 : i32
+}
+
+// rocdl.ds_bpermute (index operand first, source value second) translates to
+// a call of the llvm.amdgcn.ds.bpermute intrinsic.
+llvm.func @rocdl.bpermute(%src : i32) -> i32 {
+  // CHECK-LABEL: rocdl.bpermute
+  // CHECK: call i32 @llvm.amdgcn.ds.bpermute
+  %index = llvm.mlir.constant(10 : i32) : i32
+  %0 = rocdl.ds_bpermute %index, %src : (i32, i32) -> i32
+  llvm.return %0 : i32
+}
+
 llvm.func @rocdl.barrier() {
   // CHECK:      fence syncscope("workgroup") release
   // CHECK-NEXT: call void @llvm.amdgcn.s.barrier()


        


More information about the Mlir-commits mailing list