[Mlir-commits] [mlir] [mlir][spirv][webgpu] Add lowering of IAddCarry to IAdd (PR #68495)
Finn Plummer
llvmlistbot at llvm.org
Tue Oct 10 09:44:35 PDT 2023
https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/68495
>From c5d0ceccb2726c1dbd050c2a5ec772a723067658 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Sat, 7 Oct 2023 21:48:44 +0200
Subject: [PATCH 1/2] [mlir][spirv][webgpu] Add lowering of IAddCarry to IAdd
WebGPU does not currently support extended arithmetic, which is an issue
when lowering from SPIR-V. This commit adds a pattern that emulates
spirv.IAddCarry using spirv.IAdd operations.
Fixes #65154
---
.../Transforms/SPIRVWebGPUTransforms.cpp | 74 ++++++++++++++++++-
.../SPIRV/Transforms/webgpu-prepare.mlir | 58 +++++++++++++++
2 files changed, 130 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 44fea86785593e9..80e6d93f91353df 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -133,6 +133,48 @@ Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
}
+Value lowerCarryAddition(Operation *addOp, PatternRewriter &rewriter, Value lhs,
+ Value rhs) {
+ Location loc = addOp->getLoc();
+ Type argTy = lhs.getType();
+ // Emulate 64-bit addition by splitting each input element of type i32 to
+ // i16 similar to above in lowerExtendedMultiplication. We then expand
+ // to 3 additions:
+ // - Add two low digits into low resut
+ // - Add two high digits into high result
+ // - Add the carry from low result to high result
+ Value cstLowMask = rewriter.create<ConstantOp>(
+ loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
+ auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
+ return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
+ };
+
+ Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+ getScalarOrSplatAttr(argTy, 16));
+ auto getHighDigit = [&rewriter, loc, cst16](Value val) {
+ return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
+ };
+
+ Value lhsLow = getLowDigit(lhs);
+ Value lhsHigh = getHighDigit(lhs);
+ Value rhsLow = getLowDigit(rhs);
+ Value rhsHigh = getHighDigit(rhs);
+
+ Value low = rewriter.create<IAddOp>(loc, lhsLow, rhsLow);
+ Value high = rewriter.create<IAddOp>(loc, lhsHigh, rhsHigh);
+ Value highWithCarry = rewriter.create<IAddOp>(loc, high, getHighDigit(low));
+
+ auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
+ Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
+ return rewriter.create<BitwiseOrOp>(loc, low, highBits);
+ };
+ Value out = combineDigits(getLowDigit(highWithCarry), getLowDigit(low));
+ Value carry = getHighDigit(highWithCarry);
+
+ return rewriter.create<CompositeConstructOp>(
+ loc, addOp->getResultTypes().front(), llvm::ArrayRef({out, carry}));
+}
+
//===----------------------------------------------------------------------===//
// Rewrite Patterns
//===----------------------------------------------------------------------===//
@@ -167,6 +209,30 @@ using ExpandSMulExtendedPattern =
using ExpandUMulExtendedPattern =
ExpandMulExtendedPattern<UMulExtendedOp, false>;
+struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
+ using OpRewritePattern<IAddCarryOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IAddCarryOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ Value lhs = op.getOperand1();
+ Value rhs = op.getOperand2();
+
+ // Currently, WGSL only supports 32-bit integer types. Any other integer
+ // types should already have been promoted/demoted to i32.
+ auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
+ if (elemTy.getIntOrFloatBitWidth() != 32)
+ return rewriter.notifyMatchFailure(
+ loc,
+ llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
+
+ Value add = lowerCarryAddition(op, rewriter, lhs, rhs);
+
+ rewriter.replaceOp(op, add);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
@@ -191,8 +257,12 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns) {
// WGSL currently does not support extended multiplication ops, see:
// https://github.com/gpuweb/gpuweb/issues/1565.
- patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern>(
- patterns.getContext());
+ patterns.add<
+ // clang-format off
+ ExpandSMulExtendedPattern,
+ ExpandUMulExtendedPattern,
+ ExpandAddCarryPattern
+ >(patterns.getContext());
}
} // namespace spirv
} // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index 91eeeda6ec54c64..dbf23cddffab6dd 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -145,4 +145,62 @@ spirv.func @smul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
}
+// CHECK-LABEL: func @iaddcarry_i32
+// CHECK-SAME: ([[A:%.+]]: i32, [[B:%.+]]: i32)
+// CHECK-NEXT: [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
+// CHECK-NEXT: [[CST16:%.+]] = spirv.Constant 16 : i32
+// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[A]], [[CSTMASK]] : i32
+// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[A]], [[CST16]] : i32
+// CHECK-DAG: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[B]], [[CSTMASK]] : i32
+// CHECK-DAG: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[B]], [[CST16]] : i32
+// CHECK-DAG: [[LOW:%.+]] = spirv.IAdd [[LHSLOW]], [[RHSLOW]] : i32
+// CHECK-DAG: [[HI:%.+]] = spirv.IAdd [[LHSHI]], [[RHSHI]]
+// CHECK-DAG: [[LOWCRY:%.+]] = spirv.ShiftRightLogical [[LOW]], [[CST16]] : i32
+// CHECK-DAG: [[HI_TTL:%.+]] = spirv.IAdd [[HI]], [[LOWCRY]]
+// CHECK-DAG: spirv.ShiftRightLogical
+// CHECK-DAG: spirv.BitwiseAnd
+// CHECK-DAG: spirv.BitwiseAnd
+// CHECK-DAG: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
+// CHECK-DAG: spirv.BitwiseOr
+// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (i32, i32) -> !spirv.struct<(i32, i32)>
+// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
+spirv.func @iaddcarry_i32(%a : i32, %b : i32) -> !spirv.struct<(i32, i32)> "None" {
+ %0 = spirv.IAddCarry %a, %b : !spirv.struct<(i32, i32)>
+ spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
+}
+
+
+// CHECK-LABEL: func @iaddcarry_vector_i32
+// CHECK-SAME: ([[A:%.+]]: vector<3xi32>, [[B:%.+]]: vector<3xi32>)
+// CHECK-NEXT: [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
+// CHECK-NEXT: [[CST16:%.+]] = spirv.Constant dense<16> : vector<3xi32>
+// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[A]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[A]], [[CST16]] : vector<3xi32>
+// CHECK-DAG: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[B]], [[CSTMASK]] : vector<3xi32>
+// CHECK-DAG: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[B]], [[CST16]] : vector<3xi32>
+// CHECK-DAG: [[LOW:%.+]] = spirv.IAdd [[LHSLOW]], [[RHSLOW]] : vector<3xi32>
+// CHECK-DAG: [[HI:%.+]] = spirv.IAdd [[LHSHI]], [[RHSHI]]
+// CHECK-DAG: [[LOWCRY:%.+]] = spirv.ShiftRightLogical [[LOW]], [[CST16]] : vector<3xi32>
+// CHECK-DAG: [[HI_TTL:%.+]] = spirv.IAdd [[HI]], [[LOWCRY]]
+// CHECK-DAG: spirv.ShiftRightLogical
+// CHECK-DAG: spirv.BitwiseAnd
+// CHECK-DAG: spirv.BitwiseAnd
+// CHECK-DAG: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : vector<3xi32>
+// CHECK-DAG: spirv.BitwiseOr
+// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (vector<3xi32>, vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+spirv.func @iaddcarry_vector_i32(%a : vector<3xi32>, %b : vector<3xi32>)
+ -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
+ %0 = spirv.IAddCarry %a, %b : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ spirv.ReturnValue %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// CHECK-LABEL: func @iaddcarry_i16
+// CHECK-NEXT: spirv.IAddCarry
+// CHECK-NEXT: spirv.ReturnValue
+spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None" {
+ %0 = spirv.IAddCarry %a, %b : !spirv.struct<(i16, i16)>
+ spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
+}
+
} // end module
>From eeccf0a1b6d5f9a5815223d538743bc82d66f212 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Tue, 10 Oct 2023 16:56:44 +0200
Subject: [PATCH 2/2] Address review comments:
- inline the lowerCarryAddition function to make it clearer that we check
that each integer is i32
- compute the carry with a single unsigned comparison, which reduces the
instruction count and simplifies the lowering
---
.../Transforms/SPIRVWebGPUTransforms.cpp | 59 +++++--------------
.../SPIRV/Transforms/webgpu-prepare.mlir | 45 ++++----------
2 files changed, 27 insertions(+), 77 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 80e6d93f91353df..9a780608f1ebbde 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -133,48 +133,6 @@ Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
}
-Value lowerCarryAddition(Operation *addOp, PatternRewriter &rewriter, Value lhs,
- Value rhs) {
- Location loc = addOp->getLoc();
- Type argTy = lhs.getType();
- // Emulate 64-bit addition by splitting each input element of type i32 to
- // i16 similar to above in lowerExtendedMultiplication. We then expand
- // to 3 additions:
- // - Add two low digits into low resut
- // - Add two high digits into high result
- // - Add the carry from low result to high result
- Value cstLowMask = rewriter.create<ConstantOp>(
- loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
- auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
- return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
- };
-
- Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
- getScalarOrSplatAttr(argTy, 16));
- auto getHighDigit = [&rewriter, loc, cst16](Value val) {
- return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
- };
-
- Value lhsLow = getLowDigit(lhs);
- Value lhsHigh = getHighDigit(lhs);
- Value rhsLow = getLowDigit(rhs);
- Value rhsHigh = getHighDigit(rhs);
-
- Value low = rewriter.create<IAddOp>(loc, lhsLow, rhsLow);
- Value high = rewriter.create<IAddOp>(loc, lhsHigh, rhsHigh);
- Value highWithCarry = rewriter.create<IAddOp>(loc, high, getHighDigit(low));
-
- auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
- Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
- return rewriter.create<BitwiseOrOp>(loc, low, highBits);
- };
- Value out = combineDigits(getLowDigit(highWithCarry), getLowDigit(low));
- Value carry = getHighDigit(highWithCarry);
-
- return rewriter.create<CompositeConstructOp>(
- loc, addOp->getResultTypes().front(), llvm::ArrayRef({out, carry}));
-}
-
//===----------------------------------------------------------------------===//
// Rewrite Patterns
//===----------------------------------------------------------------------===//
@@ -220,13 +178,26 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
// Currently, WGSL only supports 32-bit integer types. Any other integer
// types should already have been promoted/demoted to i32.
- auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
+ Type argTy = lhs.getType();
+ auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
if (elemTy.getIntOrFloatBitWidth() != 32)
return rewriter.notifyMatchFailure(
loc,
llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
- Value add = lowerCarryAddition(op, rewriter, lhs, rhs);
+ Value one =
+ rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
+ Value zero =
+ rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
+
+ // Emulate 32-bit add-with-carry: let the addition wrap on overflow; the
+ // wrapped sum is (unsigned) less than lhs exactly when a carry occurred.
+ Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
+ Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
+ Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
+
+ Value add = rewriter.create<CompositeConstructOp>(
+ loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
rewriter.replaceOp(op, add);
return success();
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index dbf23cddffab6dd..1ec4e5e4f9664b8 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -147,47 +147,26 @@ spirv.func @smul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
// CHECK-LABEL: func @iaddcarry_i32
// CHECK-SAME: ([[A:%.+]]: i32, [[B:%.+]]: i32)
-// CHECK-NEXT: [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
-// CHECK-NEXT: [[CST16:%.+]] = spirv.Constant 16 : i32
-// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[A]], [[CSTMASK]] : i32
-// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[A]], [[CST16]] : i32
-// CHECK-DAG: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[B]], [[CSTMASK]] : i32
-// CHECK-DAG: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[B]], [[CST16]] : i32
-// CHECK-DAG: [[LOW:%.+]] = spirv.IAdd [[LHSLOW]], [[RHSLOW]] : i32
-// CHECK-DAG: [[HI:%.+]] = spirv.IAdd [[LHSHI]], [[RHSHI]]
-// CHECK-DAG: [[LOWCRY:%.+]] = spirv.ShiftRightLogical [[LOW]], [[CST16]] : i32
-// CHECK-DAG: [[HI_TTL:%.+]] = spirv.IAdd [[HI]], [[LOWCRY]]
-// CHECK-DAG: spirv.ShiftRightLogical
-// CHECK-DAG: spirv.BitwiseAnd
-// CHECK-DAG: spirv.BitwiseAnd
-// CHECK-DAG: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
-// CHECK-DAG: spirv.BitwiseOr
-// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (i32, i32) -> !spirv.struct<(i32, i32)>
+// CHECK-NEXT: [[ONE:%.+]] = spirv.Constant 1 : i32
+// CHECK-NEXT: [[ZERO:%.+]] = spirv.Constant 0 : i32
+// CHECK-NEXT: [[OUT:%.+]] = spirv.IAdd [[A]], [[B]]
+// CHECK-NEXT: [[CMP:%.+]] = spirv.ULessThan [[OUT]], [[A]]
+// CHECK-NEXT: [[CARRY:%.+]] = spirv.Select [[CMP]], [[ONE]], [[ZERO]]
+// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[OUT]], [[CARRY]] : (i32, i32) -> !spirv.struct<(i32, i32)>
// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
spirv.func @iaddcarry_i32(%a : i32, %b : i32) -> !spirv.struct<(i32, i32)> "None" {
%0 = spirv.IAddCarry %a, %b : !spirv.struct<(i32, i32)>
spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
}
-
// CHECK-LABEL: func @iaddcarry_vector_i32
// CHECK-SAME: ([[A:%.+]]: vector<3xi32>, [[B:%.+]]: vector<3xi32>)
-// CHECK-NEXT: [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
-// CHECK-NEXT: [[CST16:%.+]] = spirv.Constant dense<16> : vector<3xi32>
-// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[A]], [[CSTMASK]] : vector<3xi32>
-// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[A]], [[CST16]] : vector<3xi32>
-// CHECK-DAG: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[B]], [[CSTMASK]] : vector<3xi32>
-// CHECK-DAG: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[B]], [[CST16]] : vector<3xi32>
-// CHECK-DAG: [[LOW:%.+]] = spirv.IAdd [[LHSLOW]], [[RHSLOW]] : vector<3xi32>
-// CHECK-DAG: [[HI:%.+]] = spirv.IAdd [[LHSHI]], [[RHSHI]]
-// CHECK-DAG: [[LOWCRY:%.+]] = spirv.ShiftRightLogical [[LOW]], [[CST16]] : vector<3xi32>
-// CHECK-DAG: [[HI_TTL:%.+]] = spirv.IAdd [[HI]], [[LOWCRY]]
-// CHECK-DAG: spirv.ShiftRightLogical
-// CHECK-DAG: spirv.BitwiseAnd
-// CHECK-DAG: spirv.BitwiseAnd
-// CHECK-DAG: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : vector<3xi32>
-// CHECK-DAG: spirv.BitwiseOr
-// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (vector<3xi32>, vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+// CHECK-NEXT: [[ONE:%.+]] = spirv.Constant dense<1> : vector<3xi32>
+// CHECK-NEXT: [[ZERO:%.+]] = spirv.Constant dense<0> : vector<3xi32>
+// CHECK-NEXT: [[OUT:%.+]] = spirv.IAdd [[A]], [[B]]
+// CHECK-NEXT: [[CMP:%.+]] = spirv.ULessThan [[OUT]], [[A]]
+// CHECK-NEXT: [[CARRY:%.+]] = spirv.Select [[CMP]], [[ONE]], [[ZERO]]
+// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[OUT]], [[CARRY]] : (vector<3xi32>, vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)>
// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
spirv.func @iaddcarry_vector_i32(%a : vector<3xi32>, %b : vector<3xi32>)
-> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
More information about the Mlir-commits
mailing list