[Mlir-commits] [mlir] [mlir][spirv] Add IsInf/IsNan expansion for WebGPU (PR #86903)

Jakub Kuderski llvmlistbot at llvm.org
Wed Mar 27 19:15:39 PDT 2024


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/86903

These non-finite math ops are supported by SPIR-V but not by WGSL. Assume finite floating point values and expand these ops into `false`.

Previously, this worked by adding fast math flags during conversion from arith to spirv, but this got removed in
https://github.com/llvm/llvm-project/pull/86578.

Also do some misc cleanups in the surrounding code.

>From 17d951cd1aa55a8ea844bf3cd5d8d5ea1aa35c1b Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 27 Mar 2024 22:11:31 -0400
Subject: [PATCH] [mlir][spirv] Add IsInf/IsNan expansion for WebGPU

These non-finite math ops are supported by SPIR-V but not by WGSL.
Assume finite floating point values and expand these ops into `false`.

Previously, this worked by adding fast math flags during conversion from
arith to spirv, but this got removed in
https://github.com/llvm/llvm-project/pull/86578.
---
 .../SPIRV/Transforms/SPIRVWebGPUTransforms.h  | 12 +++--
 .../Transforms/SPIRVWebGPUTransforms.cpp      | 54 ++++++++++++++-----
 .../SPIRV/Transforms/webgpu-prepare.mlir      | 32 +++++++++++
 3 files changed, 82 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
index ac4d38e0c5b1eb..d0fc85ccc9de49 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
@@ -18,12 +18,18 @@
 namespace mlir {
 namespace spirv {
 
-/// Appends to a pattern list additional patterns to expand extended
-/// multiplication ops into regular arithmetic ops. Extended multiplication ops
-/// are not supported by the WebGPU Shading Language (WGSL).
+/// Appends patterns to expand extended multiplication and addition ops into
+/// regular arithmetic ops. Extended arithmetic ops are not supported by the
+/// WebGPU Shading Language (WGSL).
 void populateSPIRVExpandExtendedMultiplicationPatterns(
     RewritePatternSet &patterns);
 
+/// Appends patterns to expand non-finite arithmetic ops `IsNan` and `IsInf`.
+/// These are not supported by the WebGPU Shading Language (WGSL). We follow
+/// fast math assumptions and assume that all floating point values are finite.
+void populateSPIRVExpandNonFiniteArithmeticPatterns(
+    RewritePatternSet &patterns);
+
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 21de1c9e867c04..5d4dd5b3a1e013 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -39,7 +39,7 @@ namespace {
 //===----------------------------------------------------------------------===//
 // Helpers
 //===----------------------------------------------------------------------===//
-Attribute getScalarOrSplatAttr(Type type, int64_t value) {
+static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
   APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
   if (auto intTy = dyn_cast<IntegerType>(type))
     return IntegerAttr::get(intTy, sizedValue);
@@ -47,9 +47,9 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
   return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
 }
 
-Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
-                                  Value lhs, Value rhs,
-                                  bool signExtendArguments) {
+static Value lowerExtendedMultiplication(Operation *mulOp,
+                                         PatternRewriter &rewriter, Value lhs,
+                                         Value rhs, bool signExtendArguments) {
   Location loc = mulOp->getLoc();
   Type argTy = lhs.getType();
   // Emulate 64-bit multiplication by splitting each input element of type i32
@@ -203,15 +203,39 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
   }
 };
 
+struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IsInfOp op,
+                                PatternRewriter &rewriter) const override {
+    // We assume values to be finite and turn `IsInf` into `false`.
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+        op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
+    return success();
+  }
+};
+
+struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IsNanOp op,
+                                PatternRewriter &rewriter) const override {
+    // We assume values to be finite and turn `IsNan` into `false`.
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+        op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
-class WebGPUPreparePass
-    : public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
-public:
+struct WebGPUPreparePass final
+    : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
+    populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
 
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
@@ -227,12 +251,16 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
     RewritePatternSet &patterns) {
   // WGSL currently does not support extended multiplication ops, see:
   // https://github.com/gpuweb/gpuweb/issues/1565.
-  patterns.add<
-      // clang-format off
-    ExpandSMulExtendedPattern,
-    ExpandUMulExtendedPattern,
-    ExpandAddCarryPattern
-  >(patterns.getContext());
+  patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
+               ExpandAddCarryPattern>(patterns.getContext());
 }
+
+void populateSPIRVExpandNonFiniteArithmeticPatterns(
+    RewritePatternSet &patterns) {
+  // WGSL currently does not support `isInf` and `isNan`, see:
+  // https://github.com/gpuweb/gpuweb/pull/2311.
+  patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
+}
+
 } // namespace spirv
 } // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index 1ec4e5e4f9664b..45f188da3815cf 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -182,4 +182,36 @@ spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None
   spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
 }
 
+// CHECK-LABEL: func @is_inf_f32
+// CHECK-NEXT:       [[FALSE:%.+]] = spirv.Constant false
+// CHECK-NEXT:       spirv.ReturnValue [[FALSE]] : i1
+spirv.func @is_inf_f32(%a : f32) -> i1 "None" {
+  %0 = spirv.IsInf %a : f32
+  spirv.ReturnValue %0 : i1
+}
+
+// CHECK-LABEL: func @is_inf_4xf32
+// CHECK-NEXT:       [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
+// CHECK-NEXT:       spirv.ReturnValue [[FALSE]] : vector<4xi1>
+spirv.func @is_inf_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
+  %0 = spirv.IsInf %a : vector<4xf32>
+  spirv.ReturnValue %0 : vector<4xi1>
+}
+
+// CHECK-LABEL: func @is_nan_f32
+// CHECK-NEXT:       [[FALSE:%.+]] = spirv.Constant false
+// CHECK-NEXT:       spirv.ReturnValue [[FALSE]] : i1
+spirv.func @is_nan_f32(%a : f32) -> i1 "None" {
+  %0 = spirv.IsNan %a : f32
+  spirv.ReturnValue %0 : i1
+}
+
+// CHECK-LABEL: func @is_nan_4xf32
+// CHECK-NEXT:       [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
+// CHECK-NEXT:       spirv.ReturnValue [[FALSE]] : vector<4xi1>
+spirv.func @is_nan_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
+  %0 = spirv.IsNan %a : vector<4xf32>
+  spirv.ReturnValue %0 : vector<4xi1>
+}
+
 } // end module



More information about the Mlir-commits mailing list