[Mlir-commits] [mlir] 1b82245 - [mlir][spirv] Add smul_extended expansion for WebGPU
Jakub Kuderski
llvmlistbot at llvm.org
Thu Jan 5 17:12:47 PST 2023
Author: Jakub Kuderski
Date: 2023-01-05T20:11:47-05:00
New Revision: 1b8224537070337c2d983a204a08eb27bac1ded6
URL: https://github.com/llvm/llvm-project/commit/1b8224537070337c2d983a204a08eb27bac1ded6
DIFF: https://github.com/llvm/llvm-project/commit/1b8224537070337c2d983a204a08eb27bac1ded6.diff
LOG: [mlir][spirv] Add smul_extended expansion for WebGPU
We need this because WGSL does not support extended multiplication ops.
Fixes: https://github.com/llvm/llvm-project/issues/59563
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D141096
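For context, spirv.UMulExtended and spirv.SMulExtended return the full
double-width product of two 32-bit operands as a (low, high) result struct.
A minimal host-side sketch of the semantics this expansion must preserve
(illustrative only, not part of the commit; the helper name is made up):

  #include <cstdint>
  #include <utility>

  // Reference semantics for spirv.SMulExtended on i32: the full 64-bit
  // signed product, split into its low and high 32-bit words.
  std::pair<int32_t, int32_t> smulExtendedRef(int32_t a, int32_t b) {
    int64_t full = static_cast<int64_t>(a) * static_cast<int64_t>(b);
    auto low = static_cast<int32_t>(static_cast<uint64_t>(full) & 0xFFFFFFFFu);
    auto high = static_cast<int32_t>(static_cast<uint64_t>(full) >> 32);
    return {low, high};
  }

WGSL has no such op and no 64-bit integers, so the pass below rebuilds the
same result from 32-bit mul/add/shift/mask ops.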
Added:
mlir/test/mlir-vulkan-runner/smul_extended.mlir
Modified:
mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 5f8426b8871a..1ed71db64b8a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -15,7 +15,9 @@
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@@ -45,90 +47,126 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
return SplatElementsAttr::get(type, sizedValue);
}
+Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
+ Value lhs, Value rhs,
+ bool signExtendArguments) {
+ Location loc = mulOp->getLoc();
+ Type argTy = lhs.getType();
+ // Emulate 64-bit multiplication by splitting each input element of type i32
+ // into 2 16-bit digits of type i32. This is so that the intermediate
+ // multiplications and additions do not overflow. We extract these 16-bit
+ // digits from i32 vector elements by masking (low digit) and shifting right
+ // (high digit).
+ //
+ // The multiplication algorithm used is the standard (long) multiplication.
+ // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
+ // digits.
+ // - With zero-extended arguments, we end up emitting only 4 multiplications
+ // and 4 additions after constant folding.
+ // - With sign-extended arguments, we end up emitting 8 multiplications
+ // and 12 additions after CSE.
+ Value cstLowMask = rewriter.create<ConstantOp>(
+ loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
+ auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
+ return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
+ };
+
+ Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+ getScalarOrSplatAttr(argTy, 16));
+ auto getHighDigit = [&rewriter, loc, cst16](Value val) {
+ return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
+ };
+
+ auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
+ // We only need to shift arithmetically by 15, but the extra
+ // sign-extension bit will be truncated by the logical shift, so this is
+ // fine. We do not have to introduce an extra constant since any
+ // value in [15, 32) would do.
+ return getHighDigit(
+ rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
+ };
+
+ Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+ getScalarOrSplatAttr(argTy, 0));
+
+ Value lhsLow = getLowDigit(lhs);
+ Value lhsHigh = getHighDigit(lhs);
+ Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
+ Value rhsLow = getLowDigit(rhs);
+ Value rhsHigh = getHighDigit(rhs);
+ Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
+
+ std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
+ std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
+ std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
+
+ for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
+ for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
+ if (i + j >= resultDigits.size())
+ continue;
+
+ if (lhsDigit == cst0 || rhsDigit == cst0)
+ continue;
+
+ Value &thisResDigit = resultDigits[i + j];
+ Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
+ Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
+ thisResDigit = getLowDigit(current);
+
+ if (i + j + 1 != resultDigits.size()) {
+ Value &nextResDigit = resultDigits[i + j + 1];
+ Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
+ getHighDigit(current));
+ nextResDigit = carry;
+ }
+ }
+ }
+
+ auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
+ Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
+ return rewriter.create<BitwiseOrOp>(loc, low, highBits);
+ };
+ Value low = combineDigits(resultDigits[0], resultDigits[1]);
+ Value high = combineDigits(resultDigits[2], resultDigits[3]);
+
+ return rewriter.create<CompositeConstructOp>(
+ loc, mulOp->getResultTypes().front(), llvm::makeArrayRef({low, high}));
+}
+
//===----------------------------------------------------------------------===//
// Rewrite Patterns
//===----------------------------------------------------------------------===//
-struct ExpandUMulExtendedPattern final : OpRewritePattern<UMulExtendedOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(UMulExtendedOp op,
+template <typename MulExtendedOp, bool SignExtendArguments>
+struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
+ using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(MulExtendedOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value lhs = op.getOperand1();
Value rhs = op.getOperand2();
- Type argTy = lhs.getType();
// Currently, WGSL only supports 32-bit integer types. Any other integer
// types should already have been promoted/demoted to i32.
- auto elemTy = getElementTypeOrSelf(argTy).cast<IntegerType>();
+ auto elemTy = getElementTypeOrSelf(lhs.getType()).cast<IntegerType>();
if (elemTy.getIntOrFloatBitWidth() != 32)
return rewriter.notifyMatchFailure(
loc,
llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
- // Emulate 64-bit multiplication by splitting each input element of type i32
- // into 2 16-bit digits of type i32. This is so that the intermediate
- // multiplications and additions do not overflow. We extract these 16-bit
- // digits from i32 vector elements by masking (low digit) and shifting right
- // (high digit).
- //
- // The multiplication algorithm used is the standard (long) multiplication.
- // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
- // digits. After constant-folding, we end up emitting only 4 multiplications
- // and 4 additions.
- Value cstLowMask = rewriter.create<ConstantOp>(
- loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
- auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
- return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
- };
-
- Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
- getScalarOrSplatAttr(argTy, 16));
- auto getHighDigit = [&rewriter, loc, cst16](Value val) {
- return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
- };
-
- Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
- getScalarOrSplatAttr(argTy, 0));
-
- Value lhsLow = getLowDigit(lhs);
- Value lhsHigh = getHighDigit(lhs);
- Value rhsLow = getLowDigit(rhs);
- Value rhsHigh = getHighDigit(rhs);
-
- std::array<Value, 2> lhsDigits = {lhsLow, lhsHigh};
- std::array<Value, 2> rhsDigits = {rhsLow, rhsHigh};
- std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
-
- for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
- for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
- Value &thisResDigit = resultDigits[i + j];
- Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
- Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
- thisResDigit = getLowDigit(current);
-
- if (i + j + 1 != resultDigits.size()) {
- Value &nextResDigit = resultDigits[i + j + 1];
- Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
- getHighDigit(current));
- nextResDigit = carry;
- }
- }
- }
-
- auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
- Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
- return rewriter.create<BitwiseOrOp>(loc, low, highBits);
- };
- Value low = combineDigits(resultDigits[0], resultDigits[1]);
- Value high = combineDigits(resultDigits[2], resultDigits[3]);
-
- rewriter.replaceOpWithNewOp<CompositeConstructOp>(
- op, op.getType(), llvm::makeArrayRef({low, high}));
+ Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
+ SignExtendArguments);
+ rewriter.replaceOp(op, mul);
return success();
}
};
+using ExpandSMulExtendedPattern =
+ ExpandMulExtendedPattern<SMulExtendedOp, true>;
+using ExpandUMulExtendedPattern =
+ ExpandMulExtendedPattern<UMulExtendedOp, false>;
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
@@ -153,9 +191,8 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns) {
// WGSL currently does not support extended multiplication ops, see:
// https://github.com/gpuweb/gpuweb/issues/1565.
- // TODO(https://github.com/llvm/llvm-project/issues/59563): Add SMulExtended
- // expansion.
- patterns.add<ExpandUMulExtendedPattern>(patterns.getContext());
+ patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern>(
+ patterns.getContext());
}
} // namespace spirv
} // namespace mlir
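To make the digit scheme in lowerExtendedMultiplication easier to follow,
here is a hedged host-side model of the same algorithm (an illustrative
sketch, not code from this commit; all names are made up). It mirrors the
mask/shift digit extraction, the sign-digit trick, and the carry
propagation:

  #include <array>
  #include <cstdint>

  // Host model of the expansion: each i32 operand is split into 16-bit
  // digits stored in 32-bit words, so no partial product or carry sum can
  // overflow 32 bits.
  std::array<uint32_t, 2> mulExtendedModel(uint32_t lhs, uint32_t rhs,
                                           bool signExtend) {
    auto lowDigit = [](uint32_t v) { return v & 0xFFFFu; };
    auto highDigit = [](uint32_t v) { return v >> 16; };
    // Sign digit (0x0000 or 0xFFFF): arithmetic shift right by 16 followed
    // by the logical shift inside highDigit -- the same ShiftRightArithmetic
    // + ShiftRightLogical pair the pass emits.
    auto signDigit = [&](uint32_t v) {
      return highDigit(static_cast<uint32_t>(static_cast<int32_t>(v) >> 16));
    };

    uint32_t lhsExt = signExtend ? signDigit(lhs) : 0u;
    uint32_t rhsExt = signExtend ? signDigit(rhs) : 0u;
    std::array<uint32_t, 4> a = {lowDigit(lhs), highDigit(lhs), lhsExt, lhsExt};
    std::array<uint32_t, 4> b = {lowDigit(rhs), highDigit(rhs), rhsExt, rhsExt};
    std::array<uint32_t, 4> res = {0, 0, 0, 0};

    // Standard long multiplication over base-2^16 digits. Digit positions
    // >= 4 fall outside the 64-bit result and are dropped, as in the pass.
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4 && i + j < 4; ++j) {
        uint32_t current = res[i + j] + a[i] * b[j];
        res[i + j] = lowDigit(current);
        if (i + j + 1 < 4)
          res[i + j + 1] += highDigit(current); // Carry into the next digit.
      }
    }

    return {res[0] | (res[1] << 16), res[2] | (res[3] << 16)};
  }

For any inputs, mulExtendedModel(x, y, /*signExtend=*/true) should match
the bit pattern of the true 64-bit signed product; the runner test added
below exercises the same property end to end on a GPU.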
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index d0720a3bbd4d..91eeeda6ec54 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -70,4 +70,79 @@ spirv.func @umul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
}
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @smul_extended_i32
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32)
+// CHECK-DAG: [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
+// CHECK-DAG: [[CST16:%.+]] = spirv.Constant 16 : i32
+// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : i32
+// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : i32
+// CHECK-NEXT: [[LHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG0]], [[CST16]] : i32
+// CHECK-NEXT: [[LHSEXT:%.+]] = spirv.ShiftRightLogical [[LHSSIGN]], [[CST16]] : i32
+// CHECK-NEXT: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : i32
+// CHECK-NEXT: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : i32
+// CHECK-NEXT: [[RHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG1]], [[CST16]] : i32
+// CHECK-NEXT: [[RHSEXT:%.+]] = spirv.ShiftRightLogical [[RHSSIGN]], [[CST16]] : i32
+// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSLOW]]
+// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSHI]]
+// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSEXT]]
+// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSLOW]]
+// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSHI]]
+// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSEXT]]
+// CHECK-DAG: spirv.IMul [[LHSEXT]], [[RHSLOW]]
+// CHECK-DAG: spirv.IMul [[LHSEXT]], [[RHSHI]]
+// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
+// CHECK: spirv.BitwiseOr
+// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
+// CHECK: spirv.BitwiseOr
+// CHECK: [[RES:%.+]] = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (i32, i32) -> !spirv.struct<(i32, i32)>
+// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
+spirv.func @smul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i32)> "None" {
+ %0 = spirv.SMulExtended %arg0, %arg1 : !spirv.struct<(i32, i32)>
+ spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: func @smul_extended_vector_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi32>, [[ARG1:%.+]]: vector<3xi32>)
+// CHECK-DAG: [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
+// CHECK-DAG: [[CST16:%.+]] = spirv.Constant dense<16> : vector<3xi32>
+// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[LHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG0]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[LHSEXT:%.+]] = spirv.ShiftRightLogical [[LHSSIGN]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[RHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG1]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[RHSEXT:%.+]] = spirv.ShiftRightLogical [[RHSSIGN]], [[CST16]] : vector<3xi32>
+// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSLOW]]
+// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSHI]]
+// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSEXT]]
+// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSLOW]]
+// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSHI]]
+// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSEXT]]
+// CHECK-DAG: spirv.IMul [[LHSEXT]], [[RHSLOW]]
+// CHECK-DAG: spirv.IMul [[LHSEXT]], [[RHSHI]]
+// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
+// CHECK: spirv.BitwiseOr
+// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
+// CHECK: spirv.BitwiseOr
+// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW:%.+]], [[RESHI:%.+]]
+// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+spirv.func @smul_extended_vector_i32(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>)
+ -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
+ %0 = spirv.SMulExtended %arg0, %arg1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ spirv.ReturnValue %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// CHECK-LABEL: func @smul_extended_i16
+// CHECK-NEXT: spirv.SMulExtended
+// CHECK-NEXT: spirv.ReturnValue
+spirv.func @smul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
+ %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(i16, i16)>
+ spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
+}
+
} // end module
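A note on the eight spirv.IMul CHECK-DAG lines in the signed tests above:
the nested digit loop visits ten (i, j) pairs with i + j < 4, but the third
and fourth lhs/rhs digits are the same sign-extension value, so CSE
collapses the duplicates down to eight distinct products (matching the
"8 multiplications ... after CSE" comment in the pass). A quick hedged
enumeration of this count (illustrative only, not part of the commit):

  #include <cstdio>
  #include <set>
  #include <string>

  int main() {
    // Digit 0 = low, 1 = high, 2 and 3 = the (shared) sign-extension digit.
    const char *digit[4] = {"LOW", "HI", "EXT", "EXT"};
    std::set<std::string> products;
    for (int i = 0; i < 4; ++i)
      for (int j = 0; j < 4; ++j)
        if (i + j < 4) // Higher positions fall outside the 64-bit result.
          products.insert(std::string("LHS") + digit[i] + " * RHS" + digit[j]);
    std::printf("%zu distinct products\n", products.size()); // Prints 8.
    return 0;
  }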
diff --git a/mlir/test/mlir-vulkan-runner/smul_extended.mlir b/mlir/test/mlir-vulkan-runner/smul_extended.mlir
new file mode 100644
index 000000000000..32ad477645a2
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/smul_extended.mlir
@@ -0,0 +1,66 @@
+// Make sure that signed extended multiplication produces expected results
+// with and without expansion to primitive mul/add ops for WebGPU.
+
+// RUN: mlir-vulkan-runner %s \
+// RUN: --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN: --entry-point-result=void | FileCheck %s
+
+// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \
+// RUN: --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN: --entry-point-result=void | FileCheck %s
+
+// CHECK: [0, 1, -2, 1, 1048560, -87620295, -131071, 560969770]
+// CHECK: [0, 0, -1, 0, 0, -1, 0, -499807318]
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+ gpu.module @kernels {
+ gpu.func @kernel_add(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %arg2 : memref<8xi32>, %arg3 : memref<8xi32>)
+ kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+ %0 = gpu.block_id x
+ %lhs = memref.load %arg0[%0] : memref<8xi32>
+ %rhs = memref.load %arg1[%0] : memref<8xi32>
+ %low, %hi = arith.mulsi_extended %lhs, %rhs : i32
+ memref.store %low, %arg2[%0] : memref<8xi32>
+ memref.store %hi, %arg3[%0] : memref<8xi32>
+ gpu.return
+ }
+ }
+
+ func.func @main() {
+ %buf0 = memref.alloc() : memref<8xi32>
+ %buf1 = memref.alloc() : memref<8xi32>
+ %buf2 = memref.alloc() : memref<8xi32>
+ %buf3 = memref.alloc() : memref<8xi32>
+ %i32_0 = arith.constant 0 : i32
+
+ // Initialize output buffers.
+ %buf4 = memref.cast %buf2 : memref<8xi32> to memref<?xi32>
+ %buf5 = memref.cast %buf3 : memref<8xi32> to memref<?xi32>
+ call @fillResource1DInt(%buf4, %i32_0) : (memref<?xi32>, i32) -> ()
+ call @fillResource1DInt(%buf5, %i32_0) : (memref<?xi32>, i32) -> ()
+
+ %idx_0 = arith.constant 0 : index
+ %idx_1 = arith.constant 1 : index
+ %idx_8 = arith.constant 8 : index
+
+ // Initialize input buffers.
+ %lhs_vals = arith.constant dense<[0, 1, -1, -1, 65535, 65535, -65535, 2088183954]> : vector<8xi32>
+ %rhs_vals = arith.constant dense<[0, 1, 2, -1, 16, -1337, -65535, -1028001427]> : vector<8xi32>
+ vector.store %lhs_vals, %buf0[%idx_0] : memref<8xi32>, vector<8xi32>
+ vector.store %rhs_vals, %buf1[%idx_0] : memref<8xi32>, vector<8xi32>
+
+ gpu.launch_func @kernels::@kernel_add
+ blocks in (%idx_8, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1)
+ args(%buf0 : memref<8xi32>, %buf1 : memref<8xi32>, %buf2 : memref<8xi32>, %buf3 : memref<8xi32>)
+ %buf_low = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
+ %buf_hi = memref.cast %buf5 : memref<?xi32> to memref<*xi32>
+ call @printMemrefI32(%buf_low) : (memref<*xi32>) -> ()
+ call @printMemrefI32(%buf_hi) : (memref<*xi32>) -> ()
+ return
+ }
+ func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+ func.func private @printMemrefI32(%ptr : memref<*xi32>)
+}