[Mlir-commits] [mlir] [mlir][spirv][webgpu] Add lowering of IAddCarry to IAdd (PR #68495)

Finn Plummer llvmlistbot at llvm.org
Thu Oct 12 11:27:10 PDT 2023


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/68495

>From 8bbc9f20bced868c04138120ceed0ddf555bec43 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Sat, 7 Oct 2023 21:48:44 +0200
Subject: [PATCH 1/3] [mlir][spirv][webgpu] Add lowering of IAddCarry to IAdd

WebGPU does not currently support extended arithmetic, which is an issue
when we want to lower from SPIR-V. This commit adds a pattern to
transform and emulate spirv.IAddCarry with spirv.IAdd operations.

Fixes #65154
---
 .../Transforms/SPIRVWebGPUTransforms.cpp      | 74 ++++++++++++++++++-
 .../SPIRV/Transforms/webgpu-prepare.mlir      | 58 +++++++++++++++
 2 files changed, 130 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 44fea86785593e9..80e6d93f91353df 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -133,6 +133,48 @@ Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
       loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
 }
 
+Value lowerCarryAddition(Operation *addOp, PatternRewriter &rewriter, Value lhs,
+                         Value rhs) {
+  Location loc = addOp->getLoc();
+  Type argTy = lhs.getType();
+  // Emulate 64-bit addition by splitting each input element of type i32 to
+  // i16 similar to above in lowerExtendedMultiplication. We then expand
+  // to 3 additions:
+  //    - Add two low digits into low resut
+  //    - Add two high digits into high result
+  //    - Add the carry from low result to high result
+  Value cstLowMask = rewriter.create<ConstantOp>(
+      loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
+  auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
+    return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
+  };
+
+  Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+                                            getScalarOrSplatAttr(argTy, 16));
+  auto getHighDigit = [&rewriter, loc, cst16](Value val) {
+    return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
+  };
+
+  Value lhsLow = getLowDigit(lhs);
+  Value lhsHigh = getHighDigit(lhs);
+  Value rhsLow = getLowDigit(rhs);
+  Value rhsHigh = getHighDigit(rhs);
+
+  Value low = rewriter.create<IAddOp>(loc, lhsLow, rhsLow);
+  Value high = rewriter.create<IAddOp>(loc, lhsHigh, rhsHigh);
+  Value highWithCarry = rewriter.create<IAddOp>(loc, high, getHighDigit(low));
+
+  auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
+    Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
+    return rewriter.create<BitwiseOrOp>(loc, low, highBits);
+  };
+  Value out = combineDigits(getLowDigit(highWithCarry), getLowDigit(low));
+  Value carry = getHighDigit(highWithCarry);
+
+  return rewriter.create<CompositeConstructOp>(
+      loc, addOp->getResultTypes().front(), llvm::ArrayRef({out, carry}));
+}
+
 //===----------------------------------------------------------------------===//
 // Rewrite Patterns
 //===----------------------------------------------------------------------===//
@@ -167,6 +209,30 @@ using ExpandSMulExtendedPattern =
 using ExpandUMulExtendedPattern =
     ExpandMulExtendedPattern<UMulExtendedOp, false>;
 
+struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
+  using OpRewritePattern<IAddCarryOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IAddCarryOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value lhs = op.getOperand1();
+    Value rhs = op.getOperand2();
+
+    // Currently, WGSL only supports 32-bit integer types. Any other integer
+    // types should already have been promoted/demoted to i32.
+    auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
+    if (elemTy.getIntOrFloatBitWidth() != 32)
+      return rewriter.notifyMatchFailure(
+          loc,
+          llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
+
+    Value add = lowerCarryAddition(op, rewriter, lhs, rhs);
+
+    rewriter.replaceOp(op, add);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
@@ -191,8 +257,12 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
     RewritePatternSet &patterns) {
   // WGSL currently does not support extended multiplication ops, see:
   // https://github.com/gpuweb/gpuweb/issues/1565.
-  patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern>(
-      patterns.getContext());
+  // WebGPU also lacks extended addition, so spirv.IAddCarry is expanded too.
+  // clang-format off
+  patterns.add<ExpandSMulExtendedPattern,
+               ExpandUMulExtendedPattern,
+               ExpandAddCarryPattern>(patterns.getContext());
+  // clang-format on
 }
 } // namespace spirv
 } // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index 91eeeda6ec54c64..dbf23cddffab6dd 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -145,4 +145,62 @@ spirv.func @smul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
   spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
 }
 
+// CHECK-LABEL: func @iaddcarry_i32
+// CHECK-SAME:       ([[A:%.+]]: i32, [[B:%.+]]: i32)
+// CHECK-NEXT:       [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
+// CHECK-NEXT:       [[CST16:%.+]]   = spirv.Constant 16 : i32
+// CHECK-NEXT:       [[LHSLOW:%.+]]  = spirv.BitwiseAnd [[A]], [[CSTMASK]] : i32
+// CHECK-NEXT:       [[LHSHI:%.+]]   = spirv.ShiftRightLogical [[A]], [[CST16]] : i32
+// CHECK-DAG:        [[RHSLOW:%.+]]  = spirv.BitwiseAnd [[B]], [[CSTMASK]] : i32
+// CHECK-DAG:        [[RHSHI:%.+]]   = spirv.ShiftRightLogical [[B]], [[CST16]] : i32
+// CHECK-DAG:        [[LOW:%.+]]     = spirv.IAdd [[LHSLOW]], [[RHSLOW]] : i32
+// CHECK-DAG:        [[HI:%.+]]      = spirv.IAdd [[LHSHI]], [[RHSHI]]
+// CHECK-DAG:        [[LOWCRY:%.+]]  = spirv.ShiftRightLogical [[LOW]], [[CST16]] : i32
+// CHECK-DAG:        [[HI_TTL:%.+]]  = spirv.IAdd [[HI]], [[LOWCRY]]
+// CHECK-DAG:                          spirv.ShiftRightLogical
+// CHECK-DAG:                          spirv.BitwiseAnd
+// CHECK-DAG:                          spirv.BitwiseAnd
+// CHECK-DAG:                          spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
+// CHECK-DAG:                          spirv.BitwiseOr
+// CHECK-NEXT:       [[RES:%.+]]     = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (i32, i32) -> !spirv.struct<(i32, i32)>
+// CHECK-NEXT:       spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
+spirv.func @iaddcarry_i32(%a : i32, %b : i32) -> !spirv.struct<(i32, i32)> "None" {
+  %0 = spirv.IAddCarry %a, %b : !spirv.struct<(i32, i32)>
+  spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
+}
+
+
+// CHECK-LABEL: func @iaddcarry_vector_i32
+// CHECK-SAME:       ([[A:%.+]]: vector<3xi32>, [[B:%.+]]: vector<3xi32>)
+// CHECK-NEXT:       [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
+// CHECK-NEXT:       [[CST16:%.+]]   = spirv.Constant dense<16> : vector<3xi32>
+// CHECK-NEXT:       [[LHSLOW:%.+]]  = spirv.BitwiseAnd [[A]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT:       [[LHSHI:%.+]]   = spirv.ShiftRightLogical [[A]], [[CST16]] : vector<3xi32>
+// CHECK-DAG:        [[RHSLOW:%.+]]  = spirv.BitwiseAnd [[B]], [[CSTMASK]] : vector<3xi32>
+// CHECK-DAG:        [[RHSHI:%.+]]   = spirv.ShiftRightLogical [[B]], [[CST16]] : vector<3xi32>
+// CHECK-DAG:        [[LOW:%.+]]     = spirv.IAdd [[LHSLOW]], [[RHSLOW]] : vector<3xi32>
+// CHECK-DAG:        [[HI:%.+]]      = spirv.IAdd [[LHSHI]], [[RHSHI]]
+// CHECK-DAG:        [[LOWCRY:%.+]]  = spirv.ShiftRightLogical [[LOW]], [[CST16]] : vector<3xi32>
+// CHECK-DAG:        [[HI_TTL:%.+]]  = spirv.IAdd [[HI]], [[LOWCRY]]
+// CHECK-DAG:                          spirv.ShiftRightLogical
+// CHECK-DAG:                          spirv.BitwiseAnd
+// CHECK-DAG:                          spirv.BitwiseAnd
+// CHECK-DAG:                          spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : vector<3xi32>
+// CHECK-DAG:                          spirv.BitwiseOr
+// CHECK-NEXT:       [[RES:%.+]]     = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (vector<3xi32>, vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+// CHECK-NEXT:       spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+spirv.func @iaddcarry_vector_i32(%a : vector<3xi32>, %b : vector<3xi32>)
+  -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
+  %0 = spirv.IAddCarry %a, %b : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  spirv.ReturnValue %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// CHECK-LABEL: func @iaddcarry_i16
+// CHECK-NEXT:       spirv.IAddCarry
+// CHECK-NEXT:       spirv.ReturnValue
+spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None" {
+  %0 = spirv.IAddCarry %a, %b : !spirv.struct<(i16, i16)>
+  spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
+}
+
 } // end module

>From 5235cac099107b627dcb498fe8e7d5108a3dcfd5 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Tue, 10 Oct 2023 16:56:44 +0200
Subject: [PATCH 2/3] Address review comments:

- inline the lowerCarryAddition function to make it clearer that we check
  that each integer is of type i32

- switch the carry computation to an unsigned overflow check (IAdd,
  ULessThan, Select), which needs fewer instructions and is simpler
---
 .../Transforms/SPIRVWebGPUTransforms.cpp      | 59 +++++--------------
 .../SPIRV/Transforms/webgpu-prepare.mlir      | 45 ++++----------
 2 files changed, 27 insertions(+), 77 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 80e6d93f91353df..9a780608f1ebbde 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -133,48 +133,6 @@ Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
       loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
 }
 
-Value lowerCarryAddition(Operation *addOp, PatternRewriter &rewriter, Value lhs,
-                         Value rhs) {
-  Location loc = addOp->getLoc();
-  Type argTy = lhs.getType();
-  // Emulate 64-bit addition by splitting each input element of type i32 to
-  // i16 similar to above in lowerExtendedMultiplication. We then expand
-  // to 3 additions:
-  //    - Add two low digits into low resut
-  //    - Add two high digits into high result
-  //    - Add the carry from low result to high result
-  Value cstLowMask = rewriter.create<ConstantOp>(
-      loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
-  auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
-    return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
-  };
-
-  Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
-                                            getScalarOrSplatAttr(argTy, 16));
-  auto getHighDigit = [&rewriter, loc, cst16](Value val) {
-    return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
-  };
-
-  Value lhsLow = getLowDigit(lhs);
-  Value lhsHigh = getHighDigit(lhs);
-  Value rhsLow = getLowDigit(rhs);
-  Value rhsHigh = getHighDigit(rhs);
-
-  Value low = rewriter.create<IAddOp>(loc, lhsLow, rhsLow);
-  Value high = rewriter.create<IAddOp>(loc, lhsHigh, rhsHigh);
-  Value highWithCarry = rewriter.create<IAddOp>(loc, high, getHighDigit(low));
-
-  auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
-    Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
-    return rewriter.create<BitwiseOrOp>(loc, low, highBits);
-  };
-  Value out = combineDigits(getLowDigit(highWithCarry), getLowDigit(low));
-  Value carry = getHighDigit(highWithCarry);
-
-  return rewriter.create<CompositeConstructOp>(
-      loc, addOp->getResultTypes().front(), llvm::ArrayRef({out, carry}));
-}
-
 //===----------------------------------------------------------------------===//
 // Rewrite Patterns
 //===----------------------------------------------------------------------===//
@@ -220,13 +178,26 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
 
     // Currently, WGSL only supports 32-bit integer types. Any other integer
     // types should already have been promoted/demoted to i32.
-    auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
+    Type argTy = lhs.getType();
+    auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
     if (elemTy.getIntOrFloatBitWidth() != 32)
       return rewriter.notifyMatchFailure(
           loc,
           llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
 
-    Value add = lowerCarryAddition(op, rewriter, lhs, rhs);
+    Value one =
+        rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
+    Value zero =
+        rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
+
+    // Emulate 64-bit unsigned addition by allowing our addition to overflow,
+    // and then set the carry accordingly.
+    Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
+    Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
+    Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
+
+    Value add = rewriter.create<CompositeConstructOp>(
+        loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
 
     rewriter.replaceOp(op, add);
     return success();
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index dbf23cddffab6dd..1ec4e5e4f9664b8 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -147,47 +147,26 @@ spirv.func @smul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
 
 // CHECK-LABEL: func @iaddcarry_i32
 // CHECK-SAME:       ([[A:%.+]]: i32, [[B:%.+]]: i32)
-// CHECK-NEXT:       [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
-// CHECK-NEXT:       [[CST16:%.+]]   = spirv.Constant 16 : i32
-// CHECK-NEXT:       [[LHSLOW:%.+]]  = spirv.BitwiseAnd [[A]], [[CSTMASK]] : i32
-// CHECK-NEXT:       [[LHSHI:%.+]]   = spirv.ShiftRightLogical [[A]], [[CST16]] : i32
-// CHECK-DAG:        [[RHSLOW:%.+]]  = spirv.BitwiseAnd [[B]], [[CSTMASK]] : i32
-// CHECK-DAG:        [[RHSHI:%.+]]   = spirv.ShiftRightLogical [[B]], [[CST16]] : i32
-// CHECK-DAG:        [[LOW:%.+]]     = spirv.IAdd [[LHSLOW]], [[RHSLOW]] : i32
-// CHECK-DAG:        [[HI:%.+]]      = spirv.IAdd [[LHSHI]], [[RHSHI]]
-// CHECK-DAG:        [[LOWCRY:%.+]]  = spirv.ShiftRightLogical [[LOW]], [[CST16]] : i32
-// CHECK-DAG:        [[HI_TTL:%.+]]  = spirv.IAdd [[HI]], [[LOWCRY]]
-// CHECK-DAG:                          spirv.ShiftRightLogical
-// CHECK-DAG:                          spirv.BitwiseAnd
-// CHECK-DAG:                          spirv.BitwiseAnd
-// CHECK-DAG:                          spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
-// CHECK-DAG:                          spirv.BitwiseOr
-// CHECK-NEXT:       [[RES:%.+]]     = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (i32, i32) -> !spirv.struct<(i32, i32)>
+// CHECK-NEXT:       [[ONE:%.+]]    = spirv.Constant 1 : i32
+// CHECK-NEXT:       [[ZERO:%.+]]   = spirv.Constant 0 : i32
+// CHECK-NEXT:       [[OUT:%.+]]    = spirv.IAdd [[A]], [[B]]
+// CHECK-NEXT:       [[CMP:%.+]]    = spirv.ULessThan [[OUT]], [[A]]
+// CHECK-NEXT:       [[CARRY:%.+]]  = spirv.Select [[CMP]], [[ONE]], [[ZERO]]
+// CHECK-NEXT:       [[RES:%.+]]     = spirv.CompositeConstruct [[OUT]], [[CARRY]] : (i32, i32) -> !spirv.struct<(i32, i32)>
 // CHECK-NEXT:       spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
 spirv.func @iaddcarry_i32(%a : i32, %b : i32) -> !spirv.struct<(i32, i32)> "None" {
   %0 = spirv.IAddCarry %a, %b : !spirv.struct<(i32, i32)>
   spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
 }
 
-
 // CHECK-LABEL: func @iaddcarry_vector_i32
 // CHECK-SAME:       ([[A:%.+]]: vector<3xi32>, [[B:%.+]]: vector<3xi32>)
-// CHECK-NEXT:       [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
-// CHECK-NEXT:       [[CST16:%.+]]   = spirv.Constant dense<16> : vector<3xi32>
-// CHECK-NEXT:       [[LHSLOW:%.+]]  = spirv.BitwiseAnd [[A]], [[CSTMASK]] : vector<3xi32>
-// CHECK-NEXT:       [[LHSHI:%.+]]   = spirv.ShiftRightLogical [[A]], [[CST16]] : vector<3xi32>
-// CHECK-DAG:        [[RHSLOW:%.+]]  = spirv.BitwiseAnd [[B]], [[CSTMASK]] : vector<3xi32>
-// CHECK-DAG:        [[RHSHI:%.+]]   = spirv.ShiftRightLogical [[B]], [[CST16]] : vector<3xi32>
-// CHECK-DAG:        [[LOW:%.+]]     = spirv.IAdd [[LHSLOW]], [[RHSLOW]] : vector<3xi32>
-// CHECK-DAG:        [[HI:%.+]]      = spirv.IAdd [[LHSHI]], [[RHSHI]]
-// CHECK-DAG:        [[LOWCRY:%.+]]  = spirv.ShiftRightLogical [[LOW]], [[CST16]] : vector<3xi32>
-// CHECK-DAG:        [[HI_TTL:%.+]]  = spirv.IAdd [[HI]], [[LOWCRY]]
-// CHECK-DAG:                          spirv.ShiftRightLogical
-// CHECK-DAG:                          spirv.BitwiseAnd
-// CHECK-DAG:                          spirv.BitwiseAnd
-// CHECK-DAG:                          spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : vector<3xi32>
-// CHECK-DAG:                          spirv.BitwiseOr
-// CHECK-NEXT:       [[RES:%.+]]     = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (vector<3xi32>, vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+// CHECK-NEXT:       [[ONE:%.+]]    = spirv.Constant dense<1> : vector<3xi32>
+// CHECK-NEXT:       [[ZERO:%.+]]   = spirv.Constant dense<0> : vector<3xi32>
+// CHECK-NEXT:       [[OUT:%.+]]    = spirv.IAdd [[A]], [[B]]
+// CHECK-NEXT:       [[CMP:%.+]]    = spirv.ULessThan [[OUT]], [[A]]
+// CHECK-NEXT:       [[CARRY:%.+]]  = spirv.Select [[CMP]], [[ONE]], [[ZERO]]
+// CHECK-NEXT:       [[RES:%.+]]    = spirv.CompositeConstruct [[OUT]], [[CARRY]] : (vector<3xi32>, vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 // CHECK-NEXT:       spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 spirv.func @iaddcarry_vector_i32(%a : vector<3xi32>, %b : vector<3xi32>)
   -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {

>From 9ed4b59c37d0a630f89c2263062ff01d39331c36 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Thu, 12 Oct 2023 14:37:35 +0200
Subject: [PATCH 3/3] Address review comments

- clarify description comment
- add integration test for the vulkan runner
---
 .../Transforms/SPIRVWebGPUTransforms.cpp      |  3 +-
 .../iaddcarry_extended.mlir                   | 74 +++++++++++++++++++
 2 files changed, 75 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/mlir-vulkan-runner/iaddcarry_extended.mlir

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 9a780608f1ebbde..21de1c9e867c04e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -190,8 +190,7 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
     Value zero =
         rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
 
-    // Emulate 64-bit unsigned addition by allowing our addition to overflow,
-    // and then set the carry accordingly.
+    // Calculate the carry by checking if the addition resulted in an overflow.
     Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
     Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
     Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
diff --git a/mlir/test/mlir-vulkan-runner/iaddcarry_extended.mlir b/mlir/test/mlir-vulkan-runner/iaddcarry_extended.mlir
new file mode 100644
index 000000000000000..381a73ead03dde9
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/iaddcarry_extended.mlir
@@ -0,0 +1,74 @@
+// Make sure that addition with carry produces expected results
+// with and without expansion to primitive add/cmp ops for WebGPU.
+
+// RUN: mlir-vulkan-runner %s \
+// RUN:  --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
+// RUN:  --entry-point-result=void | FileCheck %s
+
+// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \
+// RUN:  --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
+// RUN:  --entry-point-result=void | FileCheck %s
+
+// Expected sums (modulo 2^32), followed by the expected carries.
+// CHECK: [1, 42, 0, 42]
+// CHECK: [0, 0, 1, 1]
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+  gpu.module @kernels {
+    gpu.func @kernel_add(%arg0 : memref<4xi32>, %arg1 : memref<4xi32>, %arg2 : memref<4xi32>, %arg3 : memref<4xi32>)
+      kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+      %0 = gpu.block_id x
+      %lhs = memref.load %arg0[%0] : memref<4xi32>
+      %rhs = memref.load %arg1[%0] : memref<4xi32>
+      %sum, %carry = arith.addui_extended %lhs, %rhs : i32, i1
+
+      // Widen the i1 carry to i32 so that both outputs can use the same
+      // i32 fill/print helpers in the runner.
+      %carry_i32 = arith.extui %carry : i1 to i32
+
+      memref.store %sum, %arg2[%0] : memref<4xi32>
+      memref.store %carry_i32, %arg3[%0] : memref<4xi32>
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %buf0 = memref.alloc() : memref<4xi32>
+    %buf1 = memref.alloc() : memref<4xi32>
+    %buf2 = memref.alloc() : memref<4xi32>
+    %buf3 = memref.alloc() : memref<4xi32>
+    %i32_0 = arith.constant 0 : i32
+
+    // Initialize output buffers.
+    %buf4 = memref.cast %buf2 : memref<4xi32> to memref<?xi32>
+    %buf5 = memref.cast %buf3 : memref<4xi32> to memref<?xi32>
+    call @fillResource1DInt(%buf4, %i32_0) : (memref<?xi32>, i32) -> ()
+    call @fillResource1DInt(%buf5, %i32_0) : (memref<?xi32>, i32) -> ()
+
+    %idx_0 = arith.constant 0 : index
+    %idx_1 = arith.constant 1 : index
+    %idx_4 = arith.constant 4 : index
+
+    // Initialize input buffers. Lanes 0 and 1 do not overflow; lanes 2 and 3
+    // wrap around 2^32 (-1 is 0xFFFFFFFF) and must report a carry.
+    %lhs_vals = arith.constant dense<[0, 24, -1, 43]> : vector<4xi32>
+    %rhs_vals = arith.constant dense<[1, 18, 1, -1]> : vector<4xi32>
+    vector.store %lhs_vals, %buf0[%idx_0] : memref<4xi32>, vector<4xi32>
+    vector.store %rhs_vals, %buf1[%idx_0] : memref<4xi32>, vector<4xi32>
+
+    gpu.launch_func @kernels::@kernel_add
+        blocks in (%idx_4, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1)
+        args(%buf0 : memref<4xi32>, %buf1 : memref<4xi32>, %buf2 : memref<4xi32>, %buf3 : memref<4xi32>)
+    // Print the sums, then the carries.
+    %buf_sum = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
+    %buf_carry = memref.cast %buf5 : memref<?xi32> to memref<*xi32>
+    call @printMemrefI32(%buf_sum) : (memref<*xi32>) -> ()
+    call @printMemrefI32(%buf_carry) : (memref<*xi32>) -> ()
+    return
+  }
+  func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+  func.func private @printMemrefI32(%ptr : memref<*xi32>)
+}



More information about the Mlir-commits mailing list