[llvm-branch-commits] [mlir] 1d0dc9b - [MLIR][SPIRV] Add rewrite pattern to convert select+cmp into GLSL clamp.

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Dec 23 06:52:23 PST 2020


Author: ergawy
Date: 2020-12-23T15:47:19+01:00
New Revision: 1d0dc9be6d72915d2bb632c7a46645289405dcbf

URL: https://github.com/llvm/llvm-project/commit/1d0dc9be6d72915d2bb632c7a46645289405dcbf
DIFF: https://github.com/llvm/llvm-project/commit/1d0dc9be6d72915d2bb632c7a46645289405dcbf.diff

LOG: [MLIR][SPIRV] Add rewrite pattern to convert select+cmp into GLSL clamp.

Adds rewrite patterns that convert a select+cmp chain implementing a clamp
(two selects, each guarded by the same kind of less-than comparison) into the
corresponding GLSL clamp instruction. Support is added to convert:

- FOrdLessThan, FOrdLessThanEqual to GLSLFClampOp.
- SLessThan, SLessThanEqual to GLSLSClampOp.
- ULessThan, ULessThanEqual to GLSLUClampOp.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D93618

Added: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
    mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
    mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir
    mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp

Modified: 
    mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
    mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
new file mode 100644
index 000000000000..1921dbbcfc70
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
@@ -0,0 +1,42 @@
+//===- SPIRVGLSLCanonicalization.h - GLSL-specific patterns -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares a function to register SPIR-V GLSL-specific
+// canonicalization patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVGLSLCANONICALIZATION_H_
+#define MLIR_DIALECT_SPIRV_IR_SPIRVGLSLCANONICALIZATION_H_
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+
+//===----------------------------------------------------------------------===//
+// GLSL canonicalization patterns
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace spirv {
+/// Appends to `results` the patterns that rewrite the two-select clamp idiom
+///
+///   %mid = spv.Select (cmp %min, %input), %input, %min
+///   %res = spv.Select (cmp %mid, %max), %mid, %max
+///
+/// into `spv.GLSL.FClamp`, `spv.GLSL.SClamp` or `spv.GLSL.UClamp`
+/// (%input, %min, %max), where `cmp` is the matching ordered-float, signed or
+/// unsigned less-than / less-than-or-equal comparison.
+///
+/// NOTE: GLSL.std.450 leaves clamp undefined when min > max, whereas the
+/// select chain is well defined; these patterns do not check min <= max.
+void populateSPIRVGLSLCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context);
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_IR_SPIRVGLSLCANONICALIZATION_H_

diff  --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index dbf62425878b..42c0047168b9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -5,6 +5,7 @@ add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
 add_mlir_dialect_library(MLIRSPIRV
   SPIRVAttributes.cpp
   SPIRVCanonicalization.cpp
+  SPIRVGLSLCanonicalization.cpp
   SPIRVDialect.cpp
   SPIRVEnums.cpp
   SPIRVOps.cpp

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
index 125e97360865..6e0ee4488fa1 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
@@ -38,3 +38,47 @@ def ConvertLogicalNotOfLogicalEqual : Pat<
 def ConvertLogicalNotOfLogicalNotEqual : Pat<
     (SPV_LogicalNotOp (SPV_LogicalNotEqualOp $lhs, $rhs)),
     (SPV_LogicalEqualOp $lhs, $rhs)>;
+
+//===----------------------------------------------------------------------===//
+// Re-write spv.Select + spv.<less_than_op> to a suitable variant of
+// spv.<glsl_clamp_op>
+//===----------------------------------------------------------------------===//
+//
+// Each pattern below matches the two-select clamp idiom
+//
+//   %mid = spv.Select (<less_than_op> %min, %input), %input, %min
+//   %res = spv.Select (<less_than_op> %mid, %max), %mid, %max
+//
+// and replaces %res with `<glsl_clamp_op> %input, %min, %max`. Both
+// comparisons must be the same op, which also picks the clamp flavor:
+// FOrd* -> FClamp, S* -> SClamp, U* -> UClamp.
+//
+// NOTE(review): GLSL.std.450 leaves clamp undefined when min > max, whereas
+// the select chain then yields %max; nothing here checks min <= max.
+
+// Holds when both bindings refer to the same SSA value. Used to require that
+// the outer select forwards the inner select's result as its true operand.
+def ValuesAreEqual : Constraint<CPred<"$0 == $1">>;
+
+foreach CmpClampPair = [
+    [SPV_FOrdLessThanOp, SPV_GLSLFClampOp],
+    [SPV_FOrdLessThanEqualOp, SPV_GLSLFClampOp],
+    [SPV_SLessThanOp, SPV_GLSLSClampOp],
+    [SPV_SLessThanEqualOp, SPV_GLSLSClampOp],
+    [SPV_ULessThanOp, SPV_GLSLUClampOp],
+    [SPV_ULessThanEqualOp, SPV_GLSLUClampOp]] in {
+def ConvertComparisonIntoClamp#CmpClampPair[0] : Pat<
+    (SPV_SelectOp
+        (CmpClampPair[0]
+            (SPV_SelectOp:$middle0
+                (CmpClampPair[0] $min, $input),
+                $input,
+                $min
+            ),
+            $max
+        ),
+        $middle1,
+        $max),
+    (CmpClampPair[1] $input, $min, $max),
+    [(ValuesAreEqual $middle0, $middle1)]>;
+}

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
new file mode 100644
index 000000000000..0aa413941efd
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
@@ -0,0 +1,39 @@
+//===- SPIRVGLSLCanonicalization.cpp - SPIR-V GLSL canonicalization patterns =//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the canonicalization patterns for SPIR-V GLSL-specific ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+using namespace mlir;
+
+namespace {
+// Pulls in the DRR patterns generated from SPIRVCanonicalization.td, including
+// the ConvertComparisonIntoClamp* patterns registered below.
+#include "SPIRVCanonicalization.inc"
+} // end anonymous namespace
+
+/// Adds to `results` one select+compare -> GLSL clamp rewrite per supported
+/// comparison op (FOrdLessThan[Equal], SLessThan[Equal], ULessThan[Equal]).
+///
+/// Defined via a namespace qualifier instead of reopened namespace blocks so
+/// that any mismatch with the declaration in SPIRVGLSLCanonicalization.h is a
+/// compile error rather than a silently-introduced new overload.
+void spirv::populateSPIRVGLSLCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ConvertComparisonIntoClampSPV_FOrdLessThanOp,
+                 ConvertComparisonIntoClampSPV_FOrdLessThanEqualOp,
+                 ConvertComparisonIntoClampSPV_SLessThanOp,
+                 ConvertComparisonIntoClampSPV_SLessThanEqualOp,
+                 ConvertComparisonIntoClampSPV_ULessThanOp,
+                 ConvertComparisonIntoClampSPV_ULessThanEqualOp>(context);
+}

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir
new file mode 100644
index 000000000000..90e9b85b9035
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir
@@ -0,0 +1,132 @@
+// RUN: mlir-opt -test-spirv-glsl-canonicalization -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK: func @clamp_fordlessthan(%[[INPUT:.*]]: f32)
+func @clamp_fordlessthan(%input: f32) -> f32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0.5 : f32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 1.0 : f32
+
+  // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.FOrdLessThan %min, %input : f32
+  %mid = spv.Select %0, %input, %min : i1, f32
+  %1 = spv.FOrdLessThan %mid, %max : f32
+  %2 = spv.Select %1, %mid, %max : i1, f32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : f32
+}
+
+// -----
+
+// CHECK: func @clamp_fordlessthanequal(%[[INPUT:.*]]: f32)
+func @clamp_fordlessthanequal(%input: f32) -> f32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0.5 : f32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 1.0 : f32
+
+  // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.FOrdLessThanEqual %min, %input : f32
+  %mid = spv.Select %0, %input, %min : i1, f32
+  %1 = spv.FOrdLessThanEqual %mid, %max : f32
+  %2 = spv.Select %1, %mid, %max : i1, f32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : f32
+}
+
+// -----
+
+// CHECK: func @clamp_slessthan(%[[INPUT:.*]]: si32)
+func @clamp_slessthan(%input: si32) -> si32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0 : si32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 10 : si32
+
+  // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.SLessThan %min, %input : si32
+  %mid = spv.Select %0, %input, %min : i1, si32
+  %1 = spv.SLessThan %mid, %max : si32
+  %2 = spv.Select %1, %mid, %max : i1, si32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : si32
+}
+
+// -----
+
+// CHECK: func @clamp_slessthanequal(%[[INPUT:.*]]: si32)
+func @clamp_slessthanequal(%input: si32) -> si32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0 : si32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 10 : si32
+
+  // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.SLessThanEqual %min, %input : si32
+  %mid = spv.Select %0, %input, %min : i1, si32
+  %1 = spv.SLessThanEqual %mid, %max : si32
+  %2 = spv.Select %1, %mid, %max : i1, si32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : si32
+}
+
+// -----
+
+// CHECK: func @clamp_ulessthan(%[[INPUT:.*]]: i32)
+func @clamp_ulessthan(%input: i32) -> i32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0 : i32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 10 : i32
+
+  // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.ULessThan %min, %input : i32
+  %mid = spv.Select %0, %input, %min : i1, i32
+  %1 = spv.ULessThan %mid, %max : i32
+  %2 = spv.Select %1, %mid, %max : i1, i32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : i32
+}
+
+// -----
+
+// CHECK: func @clamp_ulessthanequal(%[[INPUT:.*]]: i32)
+func @clamp_ulessthanequal(%input: i32) -> i32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0 : i32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 10 : i32
+
+  // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.ULessThanEqual %min, %input : i32
+  %mid = spv.Select %0, %input, %min : i1, i32
+  %1 = spv.ULessThanEqual %mid, %max : i32
+  %2 = spv.Select %1, %mid, %max : i1, i32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : i32
+}
+
+// -----
+
+// The outer select does not forward the inner select's result, so the chain is
+// not a clamp and must be left alone.
+// CHECK-LABEL: func @clamp_fordlessthan_not_forwarded
+func @clamp_fordlessthan_not_forwarded(%input: f32, %other: f32) -> f32 {
+  %min = spv.constant 0.5 : f32
+  %max = spv.constant 1.0 : f32
+
+  // CHECK-NOT: spv.GLSL.FClamp
+  %0 = spv.FOrdLessThan %min, %input : f32
+  %mid = spv.Select %0, %input, %min : i1, f32
+  %1 = spv.FOrdLessThan %mid, %max : f32
+  %2 = spv.Select %1, %other, %max : i1, f32
+
+  // CHECK: spv.ReturnValue
+  spv.ReturnValue %2 : f32
+}

diff  --git a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
index edcbf4ebf1de..856e5eb7f40d 100644
--- a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_mlir_library(MLIRSPIRVTestPasses
   TestAvailability.cpp
   TestEntryPointAbi.cpp
+  TestGLSLCanonicalization.cpp
   TestModuleCombiner.cpp
 
   EXCLUDE_FROM_LIBMLIR

diff  --git a/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
new file mode 100644
index 000000000000..158601fbc17a
--- /dev/null
+++ b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
@@ -0,0 +1,44 @@
+//===- TestGLSLCanonicalization.cpp - Pass to test GLSL-specific patterns ====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that greedily applies only the SPIR-V GLSL canonicalization
+/// patterns (select+compare -> GLSL clamp) to the whole module.
+///
+/// The pass has no options or other state, so the implicitly generated
+/// default and copy constructors are sufficient.
+class TestGLSLCanonicalizationPass
+    : public PassWrapper<TestGLSLCanonicalizationPass,
+                         OperationPass<mlir::ModuleOp>> {
+public:
+  void runOnOperation() override;
+};
+} // namespace
+
+void TestGLSLCanonicalizationPass::runOnOperation() {
+  OwningRewritePatternList patterns;
+  spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns, &getContext());
+  // The LogicalResult is deliberately ignored: the lit test only inspects the
+  // rewritten IR, so make the discard explicit.
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+namespace mlir {
+/// Registers this pass with mlir-opt as `-test-spirv-glsl-canonicalization`.
+void registerTestSpirvGLSLCanonicalizationPass() {
+  PassRegistration<TestGLSLCanonicalizationPass> registration(
+      "test-spirv-glsl-canonicalization",
+      "Tests SPIR-V canonicalization patterns for GLSL extension.");
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 67aa855092ef..dc68f8f4d778 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -47,6 +47,7 @@ void registerTestPrintDefUsePass();
 void registerTestPrintNestingPass();
 void registerTestReducer();
 void registerTestSpirvEntryPointABIPass();
+void registerTestSpirvGLSLCanonicalizationPass();
 void registerTestSpirvModuleCombinerPass();
 void registerTestTraitsPass();
 void registerTosaTestQuantUtilAPIPass();
@@ -115,6 +116,7 @@ void registerTestPasses() {
   registerTestPrintNestingPass();
   registerTestReducer();
   registerTestSpirvEntryPointABIPass();
+  registerTestSpirvGLSLCanonicalizationPass();
   registerTestSpirvModuleCombinerPass();
   registerTestTraitsPass();
   registerVectorizerTestPass();


        


More information about the llvm-branch-commits mailing list