[Mlir-commits] [mlir] [mlir][spirv] Add IsInf/IsNan expansion for WebGPU (PR #86903)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 27 19:16:10 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
These non-finite math ops (`IsInf` and `IsNan`) are supported by SPIR-V but not by WGSL. Assume that all floating-point values are finite and expand these ops into the constant `false`.
Previously, this was handled by adding fast-math flags during the `arith`-to-`spirv` conversion, but that was removed in
https://github.com/llvm/llvm-project/pull/86578.
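Also make some miscellaneous cleanups to the surrounding code (file-local helpers are now `static`, the pass class is now a `final` struct, and the pattern list is clang-formatted).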
---
Full diff: https://github.com/llvm/llvm-project/pull/86903.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h (+9-3)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp (+41-13)
- (modified) mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir (+32)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
index ac4d38e0c5b1eb..d0fc85ccc9de49 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
@@ -18,12 +18,18 @@
namespace mlir {
namespace spirv {
-/// Appends to a pattern list additional patterns to expand extended
-/// multiplication ops into regular arithmetic ops. Extended multiplication ops
-/// are not supported by the WebGPU Shading Language (WGSL).
+/// Appends patterns to expand extended multiplication and addition ops into
+/// regular arithmetic ops. Extended arithmetic ops are not supported by the
+/// WebGPU Shading Language (WGSL).
void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns);
+/// Appends patterns that replace the non-finite checks `IsNan` and `IsInf`
+/// with a constant `false`, since WGSL does not support them. Following
+/// fast-math semantics, all floating-point values are assumed to be finite.
+void populateSPIRVExpandNonFiniteArithmeticPatterns(
+ RewritePatternSet &patterns);
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 21de1c9e867c04..5d4dd5b3a1e013 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -39,7 +39,7 @@ namespace {
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
-Attribute getScalarOrSplatAttr(Type type, int64_t value) {
+static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
if (auto intTy = dyn_cast<IntegerType>(type))
return IntegerAttr::get(intTy, sizedValue);
@@ -47,9 +47,9 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
}
-Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
- Value lhs, Value rhs,
- bool signExtendArguments) {
+static Value lowerExtendedMultiplication(Operation *mulOp,
+ PatternRewriter &rewriter, Value lhs,
+ Value rhs, bool signExtendArguments) {
Location loc = mulOp->getLoc();
Type argTy = lhs.getType();
// Emulate 64-bit multiplication by splitting each input element of type i32
@@ -203,15 +203,39 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
}
};
+/// Replaces `spirv.IsInf` with a `false` constant of the same (scalar/vector) type.
+struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IsInfOp op,
+ PatternRewriter &rewriter) const override {
+ // We assume values to be finite and turn `IsInf` into `false`.
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+ op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
+ return success();
+ }
+};
+
+/// Replaces `spirv.IsNan` with a `false` constant of the same (scalar/vector) type.
+struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IsNanOp op,
+ PatternRewriter &rewriter) const override {
+ // We assume values to be finite and turn `IsNan` into `false`.
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+ op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
-class WebGPUPreparePass
- : public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
-public:
+struct WebGPUPreparePass final
+ : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
+ populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
@@ -227,12 +251,16 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns) {
// WGSL currently does not support extended multiplication ops, see:
// https://github.com/gpuweb/gpuweb/issues/1565.
- patterns.add<
- // clang-format off
- ExpandSMulExtendedPattern,
- ExpandUMulExtendedPattern,
- ExpandAddCarryPattern
- >(patterns.getContext());
+ patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
+ ExpandAddCarryPattern>(patterns.getContext());
}
+
+void populateSPIRVExpandNonFiniteArithmeticPatterns(
+ RewritePatternSet &patterns) {
+ // WGSL currently does not support `isInf` and `isNan`, see:
+ // https://github.com/gpuweb/gpuweb/pull/2311.
+ // Both patterns fold their op to a `false` constant (finite-math assumption).
+ patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
+}
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index 1ec4e5e4f9664b..45f188da3815cf 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -182,4 +182,36 @@ spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
}
+// IsInf/IsNan are expected to fold to a `false` constant (scalar or dense splat).
+// CHECK-LABEL: func @is_inf_f32
+// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false
+// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1
+spirv.func @is_inf_f32(%a : f32) -> i1 "None" {
+ %0 = spirv.IsInf %a : f32
+ spirv.ReturnValue %0 : i1
+}
+
+// CHECK-LABEL: func @is_inf_4xf32
+// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
+// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1>
+spirv.func @is_inf_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
+ %0 = spirv.IsInf %a : vector<4xf32>
+ spirv.ReturnValue %0 : vector<4xi1>
+}
+
+// CHECK-LABEL: func @is_nan_f32
+// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false
+// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1
+spirv.func @is_nan_f32(%a : f32) -> i1 "None" {
+ %0 = spirv.IsNan %a : f32
+ spirv.ReturnValue %0 : i1
+}
+
+// CHECK-LABEL: func @is_nan_4xf32
+// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
+// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1>
+spirv.func @is_nan_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
+ %0 = spirv.IsNan %a : vector<4xf32>
+ spirv.ReturnValue %0 : vector<4xi1>
+}
+
} // end module
``````````
</details>
https://github.com/llvm/llvm-project/pull/86903
More information about the Mlir-commits
mailing list