[Mlir-commits] [mlir] 6ddc03d - [mlir][spirv][webgpu] Add lowering of IAddCarry to IAdd (#68495)
llvmlistbot at llvm.org
Thu Oct 19 15:59:08 PDT 2023
Author: Finn Plummer
Date: 2023-10-19T15:59:04-07:00
New Revision: 6ddc03d97c0de96691ed27f05c9b30869051ce06
URL: https://github.com/llvm/llvm-project/commit/6ddc03d97c0de96691ed27f05c9b30869051ce06
DIFF: https://github.com/llvm/llvm-project/commit/6ddc03d97c0de96691ed27f05c9b30869051ce06.diff
LOG: [mlir][spirv][webgpu] Add lowering of IAddCarry to IAdd (#68495)
WebGPU does not currently support extended arithmetic, which is an issue when lowering from SPIR-V. This commit adds a pattern to emulate spirv.IAddCarry with spirv.IAdd operations.
Fixes #65154
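For reference, the pattern recovers the carry from the wrapped sum: for unsigned 32-bit values, %a + %b overflowed exactly when the wrapped result is smaller than %a. A minimal sketch of the emitted IR for the scalar i32 case, mirroring the CHECK lines in the new tests:

    %one   = spirv.Constant 1 : i32
    %zero  = spirv.Constant 0 : i32
    // Wrapped sum modulo 2^32.
    %sum   = spirv.IAdd %a, %b : i32
    // The addition overflowed iff the wrapped sum is less than an operand.
    %cmp   = spirv.ULessThan %sum, %a : i32
    %carry = spirv.Select %cmp, %one, %zero : i1, i32
    %res   = spirv.CompositeConstruct %sum, %carry : (i32, i32) -> !spirv.struct<(i32, i32)>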
Added:
mlir/test/mlir-vulkan-runner/iaddcarry_extended.mlir
Modified:
mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 44fea86785593e9..21de1c9e867c04e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -167,6 +167,42 @@ using ExpandSMulExtendedPattern =
using ExpandUMulExtendedPattern =
ExpandMulExtendedPattern<UMulExtendedOp, false>;
+struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
+ using OpRewritePattern<IAddCarryOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IAddCarryOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ Value lhs = op.getOperand1();
+ Value rhs = op.getOperand2();
+
+ // Currently, WGSL only supports 32-bit integer types. Any other integer
+ // types should already have been promoted/demoted to i32.
+ Type argTy = lhs.getType();
+ auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
+ if (elemTy.getIntOrFloatBitWidth() != 32)
+ return rewriter.notifyMatchFailure(
+ loc,
+ llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
+
+ Value one =
+ rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
+ Value zero =
+ rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
+
+ // Calculate the carry by checking if the addition resulted in an overflow.
+ Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
+ Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
+ Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
+
+ Value add = rewriter.create<CompositeConstructOp>(
+ loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
+
+ rewriter.replaceOp(op, add);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
@@ -191,8 +227,12 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns) {
// WGSL currently does not support extended multiplication ops, see:
// https://github.com/gpuweb/gpuweb/issues/1565.
- patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern>(
- patterns.getContext());
+ patterns.add<
+ // clang-format off
+ ExpandSMulExtendedPattern,
+ ExpandUMulExtendedPattern,
+ ExpandAddCarryPattern
+ >(patterns.getContext());
}
} // namespace spirv
} // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index 91eeeda6ec54c64..1ec4e5e4f9664b8 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -145,4 +145,41 @@ spirv.func @smul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
}
+// CHECK-LABEL: func @iaddcarry_i32
+// CHECK-SAME: ([[A:%.+]]: i32, [[B:%.+]]: i32)
+// CHECK-NEXT: [[ONE:%.+]] = spirv.Constant 1 : i32
+// CHECK-NEXT: [[ZERO:%.+]] = spirv.Constant 0 : i32
+// CHECK-NEXT: [[OUT:%.+]] = spirv.IAdd [[A]], [[B]]
+// CHECK-NEXT: [[CMP:%.+]] = spirv.ULessThan [[OUT]], [[A]]
+// CHECK-NEXT: [[CARRY:%.+]] = spirv.Select [[CMP]], [[ONE]], [[ZERO]]
+// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[OUT]], [[CARRY]] : (i32, i32) -> !spirv.struct<(i32, i32)>
+// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
+spirv.func @iaddcarry_i32(%a : i32, %b : i32) -> !spirv.struct<(i32, i32)> "None" {
+ %0 = spirv.IAddCarry %a, %b : !spirv.struct<(i32, i32)>
+ spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: func @iaddcarry_vector_i32
+// CHECK-SAME: ([[A:%.+]]: vector<3xi32>, [[B:%.+]]: vector<3xi32>)
+// CHECK-NEXT: [[ONE:%.+]] = spirv.Constant dense<1> : vector<3xi32>
+// CHECK-NEXT: [[ZERO:%.+]] = spirv.Constant dense<0> : vector<3xi32>
+// CHECK-NEXT: [[OUT:%.+]] = spirv.IAdd [[A]], [[B]]
+// CHECK-NEXT: [[CMP:%.+]] = spirv.ULessThan [[OUT]], [[A]]
+// CHECK-NEXT: [[CARRY:%.+]] = spirv.Select [[CMP]], [[ONE]], [[ZERO]]
+// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[OUT]], [[CARRY]] : (vector<3xi32>, vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+spirv.func @iaddcarry_vector_i32(%a : vector<3xi32>, %b : vector<3xi32>)
+ -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
+ %0 = spirv.IAddCarry %a, %b : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ spirv.ReturnValue %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// CHECK-LABEL: func @iaddcarry_i16
+// CHECK-NEXT: spirv.IAddCarry
+// CHECK-NEXT: spirv.ReturnValue
+spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None" {
+ %0 = spirv.IAddCarry %a, %b : !spirv.struct<(i16, i16)>
+ spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
+}
+
} // end module
diff --git a/mlir/test/mlir-vulkan-runner/iaddcarry_extended.mlir b/mlir/test/mlir-vulkan-runner/iaddcarry_extended.mlir
new file mode 100644
index 000000000000000..9b1f1964b3f953e
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/iaddcarry_extended.mlir
@@ -0,0 +1,68 @@
+// Make sure that addition with carry produces expected results
+// with and without expansion to primitive add/cmp ops for WebGPU.
+
+// RUN: mlir-vulkan-runner %s \
+// RUN: --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
+// RUN: --entry-point-result=void | FileCheck %s
+
+// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \
+// RUN: --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
+// RUN: --entry-point-result=void | FileCheck %s
+
+// CHECK: [0, 42, 0, 42]
+// CHECK: [1, 0, 1, 1]
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+ gpu.module @kernels {
+ gpu.func @kernel_add(%arg0 : memref<4xi32>, %arg1 : memref<4xi32>, %arg2 : memref<4xi32>, %arg3 : memref<4xi32>)
+ kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+ %0 = gpu.block_id x
+ %lhs = memref.load %arg0[%0] : memref<4xi32>
+ %rhs = memref.load %arg1[%0] : memref<4xi32>
+ %sum, %carry = arith.addui_extended %lhs, %rhs : i32, i1
+
+ %carry_i32 = arith.extui %carry : i1 to i32
+
+      memref.store %sum, %arg2[%0] : memref<4xi32>
+      memref.store %carry_i32, %arg3[%0] : memref<4xi32>
+ gpu.return
+ }
+ }
+
+ func.func @main() {
+ %buf0 = memref.alloc() : memref<4xi32>
+ %buf1 = memref.alloc() : memref<4xi32>
+ %buf2 = memref.alloc() : memref<4xi32>
+ %buf3 = memref.alloc() : memref<4xi32>
+ %i32_0 = arith.constant 0 : i32
+
+ // Initialize output buffers.
+ %buf4 = memref.cast %buf2 : memref<4xi32> to memref<?xi32>
+ %buf5 = memref.cast %buf3 : memref<4xi32> to memref<?xi32>
+ call @fillResource1DInt(%buf4, %i32_0) : (memref<?xi32>, i32) -> ()
+ call @fillResource1DInt(%buf5, %i32_0) : (memref<?xi32>, i32) -> ()
+
+ %idx_0 = arith.constant 0 : index
+ %idx_1 = arith.constant 1 : index
+ %idx_4 = arith.constant 4 : index
+
+ // Initialize input buffers.
+ %lhs_vals = arith.constant dense<[-1, 24, 4294967295, 43]> : vector<4xi32>
+ %rhs_vals = arith.constant dense<[1, 18, 1, 4294967295]> : vector<4xi32>
+ vector.store %lhs_vals, %buf0[%idx_0] : memref<4xi32>, vector<4xi32>
+ vector.store %rhs_vals, %buf1[%idx_0] : memref<4xi32>, vector<4xi32>
+
+ gpu.launch_func @kernels::@kernel_add
+ blocks in (%idx_4, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1)
+ args(%buf0 : memref<4xi32>, %buf1 : memref<4xi32>, %buf2 : memref<4xi32>, %buf3 : memref<4xi32>)
+ %buf_sum = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
+ %buf_carry = memref.cast %buf5 : memref<?xi32> to memref<*xi32>
+ call @printMemrefI32(%buf_sum) : (memref<*xi32>) -> ()
+ call @printMemrefI32(%buf_carry) : (memref<*xi32>) -> ()
+ return
+ }
+ func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+ func.func private @printMemrefI32(%ptr : memref<*xi32>)
+}