[Mlir-commits] [mlir] [mlir] AMDGPUToROCDL: lower `amdgpu.swizzle_bitmode` (PR #136223)
Ivan Butygin
llvmlistbot at llvm.org
Thu Apr 17 16:23:13 PDT 2025
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/136223
Repack `amdgpu.swizzle_bitmode` arguments and lower it to `rocdl.ds_swizzle`.
The repacking logic is as follows:
* `sizeof(arg) < sizeof(i32)`: bitcast to an integer, zext to i32, then trunc and bitcast back.
* `sizeof(arg) == sizeof(i32)`: just bitcast to i32 and back (skipped if the argument is already i32).
* `sizeof(arg) > sizeof(i32)`: bitcast to `vector<Nxi32>`, extract the individual elements, apply `rocdl.ds_swizzle` to each one, then rebuild the vector and bitcast back.
Added repacking logic to LLVM utils so it can be used elsewhere. I'm planning to use it for `gpu.shuffle` later.
>From 0916901c87442ddf2bf529870ef486f16a16b4b4 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 13 Apr 2025 06:26:34 +0200
Subject: [PATCH] [mlir] AMDGPUToROCDL: lower `amdgpu.swizzle_bitmode`
Repack `amdgpu.swizzle_bitmode` arguments and lower it to `rocdl.ds_swizzle`.
The repacking logic is as follows:
* `sizeof(arg) < sizeof(i32)`: bitcast to an integer, zext to i32, then trunc and bitcast back.
* `sizeof(arg) == sizeof(i32)`: just bitcast to i32 and back (skipped if the argument is already i32).
* `sizeof(arg) > sizeof(i32)`: bitcast to `vector<Nxi32>`, extract the individual elements, apply `rocdl.ds_swizzle` to each one, then rebuild the vector and bitcast back.
Added repacking logic to LLVM utils so it can be used elsewhere. I'm planning to use it for `gpu.shuffle` later.
---
.../mlir/Conversion/LLVMCommon/Pattern.h | 12 +++
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 33 +++++++
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 92 +++++++++++++++++++
.../Conversion/AMDGPUToROCDL/swizzle.mlir | 75 +++++++++++++++
4 files changed, 212 insertions(+)
create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 66c8731ec2bf4..7a58e4fc2f984 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -31,6 +31,18 @@ LogicalResult oneToOneRewrite(
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
} // namespace detail
+
+/// Decomposes a `src` value into a set of values of type `dstType` through a
+/// series of bitcasts and vector ops. Src and dst types are expected to be int
+/// or float types or vector types of them; in practice `dstType` is meant to
+/// be a scalar such as i32.
+///
+/// - If `src` already has type `dstType`, it is returned unchanged.
+/// - If both types have the same bit width, a single bitcast is returned.
+/// - If `src` is narrower, it is reinterpreted as an integer and zero-extended
+///   to `dstType`'s bit width; a single value is returned.
+/// - If `src` is wider, it is split into `dstType`-sized pieces via a bitcast
+///   to `vector<N x dstType>` followed by one `llvm.extractelement` per
+///   element. `dstType` must therefore be a scalar (non-vector) type here.
+SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
+ Type dstType);
+
+/// Composes a set of `src` values into a single value of type `dstType` through
+/// a series of bitcasts and vector ops. This is the inverse of
+/// `decomposeValue`: feeding back the values it produced (possibly transformed
+/// element-wise without changing their types) together with the original type
+/// as `dstType` rebuilds a value of that type. A single value is truncated
+/// and/or bitcast as needed; multiple values are inserted into a vector that
+/// is then reinterpreted as `dstType`. `src` must not be empty.
+Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
+ Type dstType);
} // namespace LLVM
/// Base class for operation conversions targeting the LLVM IR dialect. It
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5f697bdeef566..5c4c95699142f 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1377,6 +1377,38 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
}
};
+/// Lowers `amdgpu.swizzle_bitmode` to `rocdl.ds_swizzle`. The source operand is
+/// repacked into i32 values with `LLVM::decomposeValue`, every i32 is swizzled
+/// with the same BitMode offset, and the results are reassembled into the
+/// original type with `LLVM::composeValue`.
+struct AMDGPUSwizzleBitModeLowering
+    : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    unsigned andMask = op.getAndMask();
+    unsigned orMask = op.getOrMask();
+    unsigned xorMask = op.getXorMask();
+
+    // The BitMode offset packs three 5-bit fields: and_mask in bits [4:0],
+    // or_mask in [9:5] and xor_mask in [14:10]; bit 15 must stay 0 for
+    // BitMode. A mask wider than 5 bits would silently spill into the
+    // neighbouring field (xor_mask >= 32 would even set bit 15), so reject it
+    // before any IR is created.
+    constexpr unsigned maskLimit = 1u << 5;
+    if (andMask >= maskLimit || orMask >= maskLimit || xorMask >= maskLimit)
+      return rewriter.notifyMatchFailure(
+          op, "swizzle_bitmode masks must be in the range [0, 31]");
+    unsigned offset = andMask | (orMask << 5) | (xorMask << 10);
+
+    Location loc = op.getLoc();
+    Type i32 = rewriter.getI32Type();
+    Value src = adaptor.getSrc();
+    // ds_swizzle operates on 32-bit lanes, so repack the operand into i32s.
+    SmallVector<Value> decomposed =
+        LLVM::decomposeValue(rewriter, loc, src, i32);
+
+    Value offsetValue = createI32Constant(rewriter, loc, offset);
+    SmallVector<Value> swizzled;
+    swizzled.reserve(decomposed.size());
+    for (Value v : decomposed) {
+      Value res =
+          rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, offsetValue);
+      swizzled.push_back(res);
+    }
+
+    Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
@@ -1444,4 +1476,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GatherToLDSOpLowering>(converter, chipset);
+ patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 32bfd72475569..d2737c56369d1 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -381,3 +381,95 @@ LogicalResult LLVM::detail::oneToOneRewrite(
rewriter.replaceOp(op, results);
return success();
}
+
+/// Returns the total number of bits in `type`: the scalar width for integer and
+/// float types, or element count times element width for vector types.
+static unsigned getBitWidth(Type type) {
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return vecType.getNumElements() * vecType.getElementTypeBitWidth();
+
+  // Scalars: getIntOrFloatBitWidth asserts on anything other than int/float.
+  return type.getIntOrFloatBitWidth();
+}
+
+/// Materializes `value` as an `llvm.mlir.constant` of type i32.
+static Value createI32Constant(OpBuilder &builder, Location loc,
+                               int32_t value) {
+  return builder.create<LLVM::ConstantOp>(loc, builder.getI32Type(), value);
+}
+
+/// Splits or widens `src` into values of type `dstType` (see the declaration
+/// for the per-case contract). Sources whose width is not a multiple of the
+/// destination width are zero-padded; `composeValue` strips the padding again.
+SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
+                                              Value src, Type dstType) {
+  Type srcType = src.getType();
+  if (srcType == dstType)
+    return {src};
+
+  unsigned srcBitWidth = getBitWidth(srcType);
+  unsigned dstBitWidth = getBitWidth(dstType);
+  if (srcBitWidth == dstBitWidth) {
+    Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
+    return {cast};
+  }
+
+  if (dstBitWidth > srcBitWidth) {
+    // Widen: reinterpret as an integer of the source width, zero-extend it to
+    // the destination width, then reinterpret as `dstType` so the returned
+    // value always has the requested type (also for non-integer `dstType`).
+    auto smallerInt = builder.getIntegerType(srcBitWidth);
+    if (srcType != smallerInt)
+      src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src);
+
+    auto largerInt = builder.getIntegerType(dstBitWidth);
+    Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src);
+    if (largerInt != dstType)
+      res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+    return {res};
+  }
+
+  // Split: `src` is wider than `dstType`, so it becomes vector<N x dstType>.
+  assert(!isa<VectorType>(dstType) &&
+         "splitting requires a scalar destination type");
+  // Round the width up to a multiple of the destination width (e.g. a 48-bit
+  // vector<3xf16> becomes 64 bits for i32) so the bitcast below is valid.
+  unsigned paddedBitWidth =
+      (srcBitWidth + dstBitWidth - 1) / dstBitWidth * dstBitWidth;
+  int64_t numElements = paddedBitWidth / dstBitWidth;
+  if (paddedBitWidth != srcBitWidth) {
+    auto srcInt = builder.getIntegerType(srcBitWidth);
+    if (srcType != srcInt)
+      src = builder.create<LLVM::BitcastOp>(loc, srcInt, src);
+    src = builder.create<LLVM::ZExtOp>(
+        loc, builder.getIntegerType(paddedBitWidth), src);
+  }
+
+  auto vecType = VectorType::get(numElements, dstType);
+  src = builder.create<LLVM::BitcastOp>(loc, vecType, src);
+
+  SmallVector<Value> res;
+  res.reserve(numElements);
+  for (auto i : llvm::seq<int64_t>(0, numElements)) {
+    Value idx = createI32Constant(builder, loc, i);
+    Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
+    res.emplace_back(elem);
+  }
+
+  return res;
+}
+
+/// Rebuilds a value of `dstType` from pieces produced by `decomposeValue`.
+/// Multiple pieces are first packed into a vector; any excess bits (from
+/// zero-extension or padding) are then truncated away before the final
+/// bitcast to `dstType`.
+Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
+                               Type dstType) {
+  assert(!src.empty() && "src range must not be empty");
+
+  Value res;
+  if (src.size() == 1) {
+    res = src.front();
+  } else {
+    Type elemType = src.front().getType();
+    int64_t numElements = src.size();
+    auto vecType = VectorType::get(numElements, elemType);
+    res = builder.create<LLVM::PoisonOp>(loc, vecType);
+    for (auto &&[i, elem] : llvm::enumerate(src)) {
+      assert(elem.getType() == elemType && "all src values must share a type");
+      Value idx = createI32Constant(builder, loc, i);
+      res = builder.create<LLVM::InsertElementOp>(loc, vecType, res, elem, idx);
+    }
+  }
+
+  if (res.getType() == dstType)
+    return res;
+
+  unsigned srcBitWidth = getBitWidth(res.getType());
+  unsigned dstBitWidth = getBitWidth(dstType);
+  assert(srcBitWidth >= dstBitWidth &&
+         "composed value must be at least as wide as dstType");
+  if (dstBitWidth < srcBitWidth) {
+    // Drop the bits added by zero-extension/padding in decomposeValue.
+    auto largerInt = builder.getIntegerType(srcBitWidth);
+    if (res.getType() != largerInt)
+      res = builder.create<LLVM::BitcastOp>(loc, largerInt, res);
+
+    auto smallerInt = builder.getIntegerType(dstBitWidth);
+    res = builder.create<LLVM::TruncOp>(loc, smallerInt, res);
+  }
+
+  if (res.getType() != dstType)
+    res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+
+  return res;
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
new file mode 100644
index 0000000000000..ef439efde1bd0
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -convert-amdgpu-to-rocdl --canonicalize %s | FileCheck %s
+
+// Lowering of amdgpu.swizzle_bitmode to rocdl.ds_swizzle. Every test uses the
+// masks 1 2 4, i.e. the BitMode offset 1 | (2 << 5) | (4 << 10) = 4161.
+
+// i32 needs no repacking: a single swizzle applied directly to the argument.
+// CHECK-LABEL: func @test_swizzle_i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @test_swizzle_i32(%arg0 : i32) -> i32 {
+// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ARG0]], %[[C]] : (i32, i32) -> i32
+// CHECK: return %[[RES]] : i32
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : i32
+ return %0 : i32
+}
+
+// Other 32-bit scalar types are bitcast to i32 before the swizzle and back after.
+// CHECK-LABEL: func @test_swizzle_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32)
+func.func @test_swizzle_f32(%arg0 : f32) -> f32 {
+// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
+// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[CAST]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: return %[[RES_CAST]] : f32
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
+ return %0 : f32
+}
+
+// Sub-32-bit types are bitcast to an integer and zero-extended to i32 before
+// the swizzle, then truncated and bitcast back afterwards.
+// CHECK-LABEL: func @test_swizzle_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16)
+func.func @test_swizzle_f16(%arg0 : f16) -> f16 {
+// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
+// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
+// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ZEXT]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
+// CHECK: return %[[RES_CAST]] : f16
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f16
+ return %0 : f16
+}
+
+// Wider types are split into i32 elements, each swizzled separately, and the
+// results reassembled; vector<2xi32> already has i32 elements, so no bitcast.
+// CHECK-LABEL: func @test_swizzle_2xi32
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
+func.func @test_swizzle_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
+// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[E0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[E1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: return %[[V3]] : vector<2xi32>
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// vector<4xf16> (64 bits) is bitcast to vector<2xi32> before splitting and
+// bitcast back to vector<4xf16> after reassembly.
+// CHECK-LABEL: func @test_swizzle_4xf16
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
+func.func @test_swizzle_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
+// CHECK: %[[E0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[E1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[V3]] : vector<2xi32> to vector<4xf16>
+// CHECK: return %[[CAST2]] : vector<4xf16>
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<4xf16>
+ return %0 : vector<4xf16>
+}
More information about the Mlir-commits
mailing list