[Mlir-commits] [mlir] dda4b96 - [mlir] AMDGPUToROCDL: lower `amdgpu.swizzle_bitmode` (#136223)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 18 07:19:07 PDT 2025
Author: Ivan Butygin
Date: 2025-04-18T17:19:04+03:00
New Revision: dda4b968e77e1bb2c319bf2d523de3b5c4ccbb23
URL: https://github.com/llvm/llvm-project/commit/dda4b968e77e1bb2c319bf2d523de3b5c4ccbb23
DIFF: https://github.com/llvm/llvm-project/commit/dda4b968e77e1bb2c319bf2d523de3b5c4ccbb23.diff
LOG: [mlir] AMDGPUToROCDL: lower `amdgpu.swizzle_bitmode` (#136223)
Repack `amdgpu.swizzle_bitmode` arguments and lower it to
`rocdl.ds_swizzle`.
The repacking logic is as follows:
* `sizeof(arg) < sizeof(i32)`: bitcast to an integer of the same width, zext
to i32, swizzle, then trunc and bitcast back.
* `sizeof(arg) == sizeof(i32)`: bitcast to i32 and back, unless the type is already i32.
* `sizeof(arg) > sizeof(i32)`: bitcast to `vector<Nxi32>`, extract the
individual elements, apply `rocdl.ds_swizzle` to each of them, then
compose the results into a vector and bitcast back.
Added repacking logic to LLVM utils so it can be used elsewhere. I'm
planning to use it for `gpu.shuffle` later.
Added:
mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
Modified:
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/test/Dialect/AMDGPU/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 66c8731ec2bf4..7a58e4fc2f984 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -31,6 +31,18 @@ LogicalResult oneToOneRewrite(
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
} // namespace detail
+
+/// Decomposes a `src` value into a set of values of type `dstType` through a
+/// series of bitcasts and vector ops. Src and dst types are expected to be int
+/// or float types or vector types of them.
+SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
+ Type dstType);
+
+/// Composes a set of `src` values into a single value of type `dstType` through
+/// a series of bitcasts and vector ops. This is the inverse of
+/// `decomposeValue`: it combines multiple values into a single value.
+Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
+ Type dstType);
} // namespace LLVM
/// Base class for operation conversions targeting the LLVM IR dialect. It
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index a92ebf6d8e108..f14aa5a2e1564 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -38,7 +38,7 @@ def AMDGPU_Dialect : Dialect {
def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">;
def AnyIntegerOrFloatOr1DVector :
- AnyTypeOf<[AnyIntegerOrFloat, VectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
+ AnyTypeOf<[AnyIntegerOrFloat, FixedVectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
//===----------------------------------------------------------------------===//
// AMDGPU general attribute definitions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5f697bdeef566..91dbc2de65c4e 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1377,6 +1377,39 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
}
};
+struct AMDGPUSwizzleBitModeLowering
+ : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Type i32 = rewriter.getI32Type();
+ Value src = adaptor.getSrc();
+ SmallVector<Value> decomposed =
+ LLVM::decomposeValue(rewriter, loc, src, i32);
+ unsigned andMask = op.getAndMask();
+ unsigned orMask = op.getOrMask();
+ unsigned xorMask = op.getXorMask();
+
+ // bit 15 is 0 for the BitMode swizzle.
+ // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
+ unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
+ Value maskValue = createI32Constant(rewriter, loc, mask);
+ SmallVector<Value> swizzled;
+ for (Value v : decomposed) {
+ Value res =
+ rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
+ swizzled.emplace_back(res);
+ }
+
+ Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
@@ -1444,4 +1477,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GatherToLDSOpLowering>(converter, chipset);
+ patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 32bfd72475569..1ae99561e9d1b 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -381,3 +381,96 @@ LogicalResult LLVM::detail::oneToOneRewrite(
rewriter.replaceOp(op, results);
return success();
}
+
+static unsigned getBitWidth(Type type) {
+ if (type.isIntOrFloat())
+ return type.getIntOrFloatBitWidth();
+
+ auto vec = cast<VectorType>(type);
+ assert(!vec.isScalable() && "scalable vectors are not supported");
+ return vec.getNumElements() * getBitWidth(vec.getElementType());
+}
+
+static Value createI32Constant(OpBuilder &builder, Location loc,
+ int32_t value) {
+ Type i32 = builder.getI32Type();
+ return builder.create<LLVM::ConstantOp>(loc, i32, value);
+}
+
+SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
+ Value src, Type dstType) {
+ Type srcType = src.getType();
+ if (srcType == dstType)
+ return {src};
+
+ unsigned srcBitWidth = getBitWidth(srcType);
+ unsigned dstBitWidth = getBitWidth(dstType);
+ if (srcBitWidth == dstBitWidth) {
+ Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
+ return {cast};
+ }
+
+ if (dstBitWidth > srcBitWidth) {
+ auto smallerInt = builder.getIntegerType(srcBitWidth);
+ if (srcType != smallerInt)
+ src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src);
+
+ auto largerInt = builder.getIntegerType(dstBitWidth);
+ Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src);
+ return {res};
+ }
+ assert(srcBitWidth % dstBitWidth == 0 &&
+ "src bit width must be a multiple of dst bit width");
+ int64_t numElements = srcBitWidth / dstBitWidth;
+ auto vecType = VectorType::get(numElements, dstType);
+
+ src = builder.create<LLVM::BitcastOp>(loc, vecType, src);
+
+ SmallVector<Value> res;
+ for (auto i : llvm::seq(numElements)) {
+ Value idx = createI32Constant(builder, loc, i);
+ Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
+ res.emplace_back(elem);
+ }
+
+ return res;
+}
+
+Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
+ Type dstType) {
+ assert(!src.empty() && "src range must not be empty");
+ if (src.size() == 1) {
+ Value res = src.front();
+ if (res.getType() == dstType)
+ return res;
+
+ unsigned srcBitWidth = getBitWidth(res.getType());
+ unsigned dstBitWidth = getBitWidth(dstType);
+ if (dstBitWidth < srcBitWidth) {
+ auto largerInt = builder.getIntegerType(srcBitWidth);
+ if (res.getType() != largerInt)
+ res = builder.create<LLVM::BitcastOp>(loc, largerInt, res);
+
+ auto smallerInt = builder.getIntegerType(dstBitWidth);
+ res = builder.create<LLVM::TruncOp>(loc, smallerInt, res);
+ }
+
+ if (res.getType() != dstType)
+ res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+
+ return res;
+ }
+
+ int64_t numElements = src.size();
+ auto srcType = VectorType::get(numElements, src.front().getType());
+ Value res = builder.create<LLVM::PoisonOp>(loc, srcType);
+ for (auto &&[i, elem] : llvm::enumerate(src)) {
+ Value idx = createI32Constant(builder, loc, i);
+ res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx);
+ }
+
+ if (res.getType() != dstType)
+ res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+
+ return res;
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
new file mode 100644
index 0000000000000..ef439efde1bd0
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -convert-amdgpu-to-rocdl --canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: func @test_swizzle_i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @test_swizzle_i32(%arg0 : i32) -> i32 {
+// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ARG0]], %[[C]] : (i32, i32) -> i32
+// CHECK: return %[[RES]] : i32
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: func @test_swizzle_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32)
+func.func @test_swizzle_f32(%arg0 : f32) -> f32 {
+// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
+// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[CAST]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: return %[[RES_CAST]] : f32
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @test_swizzle_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16)
+func.func @test_swizzle_f16(%arg0 : f16) -> f16 {
+// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
+// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
+// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ZEXT]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
+// CHECK: return %[[RES_CAST]] : f16
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: func @test_swizzle_2xi32
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
+func.func @test_swizzle_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
+// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[E0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[E1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: return %[[V3]] : vector<2xi32>
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @test_swizzle_4xf16
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
+func.func @test_swizzle_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
+// CHECK: %[[E0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[E1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[V3]] : vector<2xi32> to vector<4xf16>
+// CHECK: return %[[CAST2]] : vector<4xf16>
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<4xf16>
+ return %0 : vector<4xf16>
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 40f98ff85688c..73306ba6b3f93 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -154,7 +154,15 @@ func.func @fat_raw_buffer_cast_stripping_offset_affine_map(%m: memref<8xi32, aff
// -----
func.func @swizzle_invalid_type(%arg0 : si32) -> si32 {
- // expected-error @+1 {{amdgpu.swizzle_bitmode' op operand #0 must be Integer or Float or vector of Integer or Float values of ranks 1}}
+ // expected-error @+1 {{'amdgpu.swizzle_bitmode' op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1}}
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : si32
func.return %0 : si32
}
+
+// -----
+
+func.func @swizzle_scalable_vec(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+ // expected-error @+1 {{'amdgpu.swizzle_bitmode' op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1}}
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<[4]xf32>
+ func.return %0 : vector<[4]xf32>
+}
More information about the Mlir-commits
mailing list