[llvm-branch-commits] [mlir] 1d0dc9b - [MLIR][SPIRV] Add rewrite pattern to convert select+cmp into GLSL clamp.
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Dec 23 06:52:23 PST 2020
Author: ergawy
Date: 2020-12-23T15:47:19+01:00
New Revision: 1d0dc9be6d72915d2bb632c7a46645289405dcbf
URL: https://github.com/llvm/llvm-project/commit/1d0dc9be6d72915d2bb632c7a46645289405dcbf
DIFF: https://github.com/llvm/llvm-project/commit/1d0dc9be6d72915d2bb632c7a46645289405dcbf.diff
LOG: [MLIR][SPIRV] Add rewrite pattern to convert select+cmp into GLSL clamp.
Adds rewrite patterns to convert select+cmp instructions into clamp
instructions whenever possible. Support is added to convert:
- FOrdLessThan, FOrdLessThanEqual to GLSLFClampOp.
- SLessThan, SLessThanEqual to GLSLSClampOp.
- ULessThan, ULessThanEqual to GLSLUClampOp.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D93618
Added:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir
mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
Modified:
mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
new file mode 100644
index 000000000000..1921dbbcfc70
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
@@ -0,0 +1,35 @@
+//===- SPIRVGLSLCanonicalization.h - GLSL-specific patterns -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares a function to register SPIR-V GLSL-specific
+// canonicalization patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVGLSLCANONICALIZATION_H_
+#define MLIR_DIALECT_SPIRV_IR_SPIRVGLSLCANONICALIZATION_H_
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+
+//===----------------------------------------------------------------------===//
+// GLSL canonicalization patterns
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace spirv {
+/// Appends to `results` the patterns that rewrite the select+compare sequence
+///   mid = (min < x) ? x : min;  res = (mid < max) ? mid : max
+/// (and its `<=` variants) into spv.GLSL.{F,S,U}Clamp(x, min, max), chosen by
+/// the float/signed/unsigned comparison op that was matched.
+void populateSPIRVGLSLCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context);
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_IR_SPIRVGLSLCANONICALIZATION_H_
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index dbf62425878b..42c0047168b9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -5,6 +5,7 @@ add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
add_mlir_dialect_library(MLIRSPIRV
SPIRVAttributes.cpp
SPIRVCanonicalization.cpp
+ SPIRVGLSLCanonicalization.cpp
SPIRVDialect.cpp
SPIRVEnums.cpp
SPIRVOps.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
index 125e97360865..6e0ee4488fa1 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
@@ -38,3 +38,50 @@ def ConvertLogicalNotOfLogicalEqual : Pat<
 def ConvertLogicalNotOfLogicalNotEqual : Pat<
   (SPV_LogicalNotOp (SPV_LogicalNotEqualOp $lhs, $rhs)),
   (SPV_LogicalEqualOp $lhs, $rhs)>;
+
+//===----------------------------------------------------------------------===//
+// Re-write spv.Select + spv.<less_than_op> to a suitable variant of
+// spv.<glsl_clamp_op>
+//===----------------------------------------------------------------------===//
+//
+// Each pattern below matches
+//
+//   %cond0 = <less_than_op> %min, %input
+//   %mid   = spv.Select %cond0, %input, %min    // max(%min, %input)
+//   %cond1 = <less_than_op> %mid, %max
+//   %res   = spv.Select %cond1, %mid, %max      // min(%mid, %max)
+//
+// and replaces %res with <glsl_clamp_op> %input, %min, %max.
+//
+// NOTE(review): GLSL.std.450 leaves FClamp/SClamp/UClamp undefined when
+// min > max, and FClamp inherits FMin/FMax's unspecified result for NaN
+// operands, whereas the select sequence is fully defined in both cases.
+// These patterns do not check min <= max -- TODO confirm they are only
+// applied where the bounds are known to be ordered.
+
+// Holds when both operands are the same SSA value; used below to require
+// that the outer select returns the result of the inner select.
+def ValuesAreEqual : Constraint<CPred<"$0 == $1">>;
+
+foreach CmpClampPair = [
+    [SPV_FOrdLessThanOp, SPV_GLSLFClampOp],
+    [SPV_FOrdLessThanEqualOp, SPV_GLSLFClampOp],
+    [SPV_SLessThanOp, SPV_GLSLSClampOp],
+    [SPV_SLessThanEqualOp, SPV_GLSLSClampOp],
+    [SPV_ULessThanOp, SPV_GLSLUClampOp],
+    [SPV_ULessThanEqualOp, SPV_GLSLUClampOp]] in {
+def ConvertComparisonIntoClamp#CmpClampPair[0] : Pat<
+  (SPV_SelectOp
+    (CmpClampPair[0]
+      (SPV_SelectOp:$middle0
+        (CmpClampPair[0] $min, $input),
+        $input,
+        $min
+      ),
+      $max
+    ),
+    $middle1,
+    $max),
+  (CmpClampPair[1] $input, $min, $max),
+  [(ValuesAreEqual $middle0, $middle1)]>;
+}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
new file mode 100644
index 000000000000..0aa413941efd
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
@@ -0,0 +1,41 @@
+//===- SPIRVGLSLCanonicalization.cpp - SPIR-V GLSL canonicalization patterns =//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file registers the SPIR-V GLSL-specific canonicalization patterns.
+// The patterns themselves are declared in SPIRVCanonicalization.td and are
+// generated into SPIRVCanonicalization.inc.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+using namespace mlir;
+
+namespace {
+// The generated file also holds the other SPIR-V canonicalization patterns;
+// only the ConvertComparisonIntoClamp* ones are registered below.
+#include "SPIRVCanonicalization.inc"
+} // end anonymous namespace
+
+namespace mlir {
+namespace spirv {
+/// Adds the patterns that rewrite select+compare sequences into
+/// spv.GLSL.{F,S,U}Clamp to `results`, constructing each with `context`.
+void populateSPIRVGLSLCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ConvertComparisonIntoClampSPV_FOrdLessThanOp,
+                 ConvertComparisonIntoClampSPV_FOrdLessThanEqualOp,
+                 ConvertComparisonIntoClampSPV_SLessThanOp,
+                 ConvertComparisonIntoClampSPV_SLessThanEqualOp,
+                 ConvertComparisonIntoClampSPV_ULessThanOp,
+                 ConvertComparisonIntoClampSPV_ULessThanEqualOp>(context);
+}
+} // namespace spirv
+} // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir
new file mode 100644
index 000000000000..90e9b85b9035
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir
@@ -0,0 +1,150 @@
+// RUN: mlir-opt -test-spirv-glsl-canonicalization -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK: func @clamp_fordlessthan(%[[INPUT:.*]]: f32)
+func @clamp_fordlessthan(%input: f32) -> f32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0.5 : f32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 1.0 : f32
+
+  // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.FOrdLessThan %min, %input : f32
+  %mid = spv.Select %0, %input, %min : i1, f32
+  %1 = spv.FOrdLessThan %mid, %max : f32
+  %2 = spv.Select %1, %mid, %max : i1, f32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : f32
+}
+
+// -----
+
+// CHECK: func @clamp_fordlessthanequal(%[[INPUT:.*]]: f32)
+func @clamp_fordlessthanequal(%input: f32) -> f32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0.5 : f32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 1.0 : f32
+
+  // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.FOrdLessThanEqual %min, %input : f32
+  %mid = spv.Select %0, %input, %min : i1, f32
+  %1 = spv.FOrdLessThanEqual %mid, %max : f32
+  %2 = spv.Select %1, %mid, %max : i1, f32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : f32
+}
+
+// -----
+
+// CHECK: func @clamp_slessthan(%[[INPUT:.*]]: si32)
+func @clamp_slessthan(%input: si32) -> si32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0 : si32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 10 : si32
+
+  // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.SLessThan %min, %input : si32
+  %mid = spv.Select %0, %input, %min : i1, si32
+  %1 = spv.SLessThan %mid, %max : si32
+  %2 = spv.Select %1, %mid, %max : i1, si32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : si32
+}
+
+// -----
+
+// CHECK: func @clamp_slessthanequal(%[[INPUT:.*]]: si32)
+func @clamp_slessthanequal(%input: si32) -> si32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0 : si32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 10 : si32
+
+  // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.SLessThanEqual %min, %input : si32
+  %mid = spv.Select %0, %input, %min : i1, si32
+  %1 = spv.SLessThanEqual %mid, %max : si32
+  %2 = spv.Select %1, %mid, %max : i1, si32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : si32
+}
+
+// -----
+
+// CHECK: func @clamp_ulessthan(%[[INPUT:.*]]: i32)
+func @clamp_ulessthan(%input: i32) -> i32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0 : i32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 10 : i32
+
+  // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.ULessThan %min, %input : i32
+  %mid = spv.Select %0, %input, %min : i1, i32
+  %1 = spv.ULessThan %mid, %max : i32
+  %2 = spv.Select %1, %mid, %max : i1, i32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : i32
+}
+
+// -----
+
+// CHECK: func @clamp_ulessthanequal(%[[INPUT:.*]]: i32)
+func @clamp_ulessthanequal(%input: i32) -> i32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0 : i32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 10 : i32
+
+  // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.ULessThanEqual %min, %input : i32
+  %mid = spv.Select %0, %input, %min : i1, i32
+  %1 = spv.ULessThanEqual %mid, %max : i32
+  %2 = spv.Select %1, %mid, %max : i1, i32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : i32
+}
+
+// -----
+
+// The outer select must return the inner select's result; otherwise the
+// sequence is not a clamp and must be left alone.
+// CHECK-LABEL: func @clamp_fordlessthan_mismatched_middle
+func @clamp_fordlessthan_mismatched_middle(%input: f32, %other: f32) -> f32 {
+  %min = spv.constant 0.5 : f32
+  %max = spv.constant 1.0 : f32
+
+  // CHECK-NOT: spv.GLSL.FClamp
+  %0 = spv.FOrdLessThan %min, %input : f32
+  %mid = spv.Select %0, %input, %min : i1, f32
+  %1 = spv.FOrdLessThan %mid, %max : f32
+  %2 = spv.Select %1, %other, %max : i1, f32
+
+  // CHECK: spv.ReturnValue
+  spv.ReturnValue %2 : f32
+}
+
+// -----
+
+// The inner select must fall back to the value it was compared against.
+// CHECK-LABEL: func @clamp_slessthan_mismatched_min
+func @clamp_slessthan_mismatched_min(%input: si32) -> si32 {
+  %min = spv.constant 0 : si32
+  %max = spv.constant 10 : si32
+
+  // CHECK-NOT: spv.GLSL.SClamp
+  %0 = spv.SLessThan %min, %input : si32
+  %mid = spv.Select %0, %input, %max : i1, si32
+  %1 = spv.SLessThan %mid, %max : si32
+  %2 = spv.Select %1, %mid, %max : i1, si32
+
+  // CHECK: spv.ReturnValue
+  spv.ReturnValue %2 : si32
+}
diff --git a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
index edcbf4ebf1de..856e5eb7f40d 100644
--- a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
@@ -2,6 +2,7 @@
add_mlir_library(MLIRSPIRVTestPasses
TestAvailability.cpp
TestEntryPointAbi.cpp
+ TestGLSLCanonicalization.cpp
TestModuleCombiner.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
new file mode 100644
index 000000000000..158601fbc17a
--- /dev/null
+++ b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
@@ -0,0 +1,42 @@
+//===- TestGLSLCanonicalization.cpp - Pass to test GLSL-specific patterns ====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that greedily applies only the SPIR-V GLSL-specific
+/// canonicalization patterns to the module, so they can be exercised without
+/// the rest of the canonicalizer.
+class TestGLSLCanonicalizationPass
+    : public PassWrapper<TestGLSLCanonicalizationPass,
+                         OperationPass<mlir::ModuleOp>> {
+public:
+  void runOnOperation() override;
+};
+} // namespace
+
+void TestGLSLCanonicalizationPass::runOnOperation() {
+  OwningRewritePatternList patterns;
+  spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns, &getContext());
+  // Non-convergence is not treated as a pass failure; the test only inspects
+  // the rewritten IR, so the driver's result is deliberately discarded.
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+namespace mlir {
+void registerTestSpirvGLSLCanonicalizationPass() {
+  PassRegistration<TestGLSLCanonicalizationPass> registration(
+      "test-spirv-glsl-canonicalization",
+      "Tests SPIR-V canonicalization patterns for GLSL extension.");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 67aa855092ef..dc68f8f4d778 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -47,6 +47,7 @@ void registerTestPrintDefUsePass();
void registerTestPrintNestingPass();
void registerTestReducer();
void registerTestSpirvEntryPointABIPass();
+void registerTestSpirvGLSLCanonicalizationPass();
void registerTestSpirvModuleCombinerPass();
void registerTestTraitsPass();
void registerTosaTestQuantUtilAPIPass();
@@ -115,6 +116,7 @@ void registerTestPasses() {
registerTestPrintNestingPass();
registerTestReducer();
registerTestSpirvEntryPointABIPass();
+ registerTestSpirvGLSLCanonicalizationPass();
registerTestSpirvModuleCombinerPass();
registerTestTraitsPass();
registerVectorizerTestPass();
More information about the llvm-branch-commits
mailing list