[Mlir-commits] [mlir] 55a4df9 - [mlir][spirv] Handle another form of folding comparison into clamp

Lei Zhang llvmlistbot at llvm.org
Tue Mar 8 12:55:55 PST 2022


Author: Lei Zhang
Date: 2022-03-08T15:53:22-05:00
New Revision: 55a4df9c1424948943c2095a124157472860015b

URL: https://github.com/llvm/llvm-project/commit/55a4df9c1424948943c2095a124157472860015b
DIFF: https://github.com/llvm/llvm-project/commit/55a4df9c1424948943c2095a124157472860015b.diff

LOG: [mlir][spirv] Handle another form of folding comparison into clamp

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D121227

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
    mlir/test/Dialect/SPIRV/Transforms/glsl-canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
index 6e0ee4488fa14..f6ad2283c0629 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
@@ -40,8 +40,7 @@ def ConvertLogicalNotOfLogicalNotEqual : Pat<
     (SPV_LogicalEqualOp $lhs, $rhs)>;
 
 //===----------------------------------------------------------------------===//
-// Re-write spv.Select + spv.<less_than_op> to a suitable variant of
-// spv.<glsl_clamp_op>
+// spv.Select -> spv.GLSL.*Clamp
 //===----------------------------------------------------------------------===//
 
 def ValuesAreEqual : Constraint<CPred<"$0 == $1">>;
@@ -53,7 +52,9 @@ foreach CmpClampPair = [
     [SPV_SLessThanEqualOp, SPV_GLSLSClampOp],
     [SPV_ULessThanOp, SPV_GLSLUClampOp],
     [SPV_ULessThanEqualOp, SPV_GLSLUClampOp]] in {
-def ConvertComparisonIntoClamp#CmpClampPair[0] : Pat<
+
+// Detect: $min < $input, $input < $max
+def ConvertComparisonIntoClamp1_#CmpClampPair[0] : Pat<
     (SPV_SelectOp
         (CmpClampPair[0]
             (SPV_SelectOp:$middle0
@@ -67,4 +68,16 @@ def ConvertComparisonIntoClamp#CmpClampPair[0] : Pat<
         $max),
     (CmpClampPair[1] $input, $min, $max),
     [(ValuesAreEqual $middle0, $middle1)]>;
+
+// Detect: $input < $min, $max < $input
+def ConvertComparisonIntoClamp2_#CmpClampPair[0] : Pat<
+    (SPV_SelectOp
+        (CmpClampPair[0] $max, $input),
+        $max,
+        (SPV_SelectOp
+            (CmpClampPair[0] $input, $min),
+            $min,
+            $input
+        )),
+    (CmpClampPair[1] $input, $min, $max)>;
 }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
index ccd4c1d920641..2ea8503909ff7 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
@@ -23,12 +23,18 @@ namespace {
 namespace mlir {
 namespace spirv {
 void populateSPIRVGLSLCanonicalizationPatterns(RewritePatternSet &results) {
-  results.add<ConvertComparisonIntoClampSPV_FOrdLessThanOp,
-              ConvertComparisonIntoClampSPV_FOrdLessThanEqualOp,
-              ConvertComparisonIntoClampSPV_SLessThanOp,
-              ConvertComparisonIntoClampSPV_SLessThanEqualOp,
-              ConvertComparisonIntoClampSPV_ULessThanOp,
-              ConvertComparisonIntoClampSPV_ULessThanEqualOp>(
+  results.add<ConvertComparisonIntoClamp1_SPV_FOrdLessThanOp,
+              ConvertComparisonIntoClamp1_SPV_FOrdLessThanEqualOp,
+              ConvertComparisonIntoClamp1_SPV_SLessThanOp,
+              ConvertComparisonIntoClamp1_SPV_SLessThanEqualOp,
+              ConvertComparisonIntoClamp1_SPV_ULessThanOp,
+              ConvertComparisonIntoClamp1_SPV_ULessThanEqualOp,
+              ConvertComparisonIntoClamp2_SPV_FOrdLessThanOp,
+              ConvertComparisonIntoClamp2_SPV_FOrdLessThanEqualOp,
+              ConvertComparisonIntoClamp2_SPV_SLessThanOp,
+              ConvertComparisonIntoClamp2_SPV_SLessThanEqualOp,
+              ConvertComparisonIntoClamp2_SPV_ULessThanOp,
+              ConvertComparisonIntoClamp2_SPV_ULessThanEqualOp>(
       results.getContext());
 }
 } // namespace spirv

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/glsl-canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/glsl-canonicalize.mlir
index 7b0ad54a3a151..9b77a971d5a48 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/glsl-canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/glsl-canonicalize.mlir
@@ -1,12 +1,8 @@
 // RUN: mlir-opt -split-input-file -spirv-canonicalize-glsl %s | FileCheck %s
 
-// CHECK: func @clamp_fordlessthan(%[[INPUT:.*]]: f32)
-func @clamp_fordlessthan(%input: f32) -> f32 {
-  // CHECK: %[[MIN:.*]] = spv.Constant
-  %min = spv.Constant 0.5 : f32
-  // CHECK: %[[MAX:.*]] = spv.Constant
-  %max = spv.Constant 1.0 : f32
-
+// CHECK-LABEL: func @clamp_fordlessthan
+//  CHECK-SAME: (%[[INPUT:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32)
+func @clamp_fordlessthan(%input: f32, %min: f32, %max: f32) -> f32 {
   // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
   %0 = spv.FOrdLessThan %min, %input : f32
   %mid = spv.Select %0, %input, %min : i1, f32
@@ -19,13 +15,24 @@ func @clamp_fordlessthan(%input: f32) -> f32 {
 
 // -----
 
-// CHECK: func @clamp_fordlessthanequal(%[[INPUT:.*]]: f32)
-func @clamp_fordlessthanequal(%input: f32) -> f32 {
-  // CHECK: %[[MIN:.*]] = spv.Constant
-  %min = spv.Constant 0.5 : f32
-  // CHECK: %[[MAX:.*]] = spv.Constant
-  %max = spv.Constant 1.0 : f32
+// CHECK-LABEL: func @clamp_fordlessthan
+//  CHECK-SAME: (%[[INPUT:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32)
+func @clamp_fordlessthan(%input: f32, %min: f32, %max: f32) -> f32 {
+  // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.FOrdLessThan %input, %min : f32
+  %mid = spv.Select %0, %min, %input : i1, f32
+  %1 = spv.FOrdLessThan %max, %input : f32
+  %2 = spv.Select %1, %max, %mid : i1, f32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : f32
+}
 
+// -----
+
+// CHECK-LABEL: func @clamp_fordlessthanequal
+//  CHECK-SAME: (%[[INPUT:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32)
+func @clamp_fordlessthanequal(%input: f32, %min: f32, %max: f32) -> f32 {
   // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
   %0 = spv.FOrdLessThanEqual %min, %input : f32
   %mid = spv.Select %0, %input, %min : i1, f32
@@ -38,13 +45,24 @@ func @clamp_fordlessthanequal(%input: f32) -> f32 {
 
 // -----
 
-// CHECK: func @clamp_slessthan(%[[INPUT:.*]]: si32)
-func @clamp_slessthan(%input: si32) -> si32 {
-  // CHECK: %[[MIN:.*]] = spv.Constant
-  %min = spv.Constant 0 : si32
-  // CHECK: %[[MAX:.*]] = spv.Constant
-  %max = spv.Constant 10 : si32
+// CHECK-LABEL: func @clamp_fordlessthanequal
+//  CHECK-SAME: (%[[INPUT:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32)
+func @clamp_fordlessthanequal(%input: f32, %min: f32, %max: f32) -> f32 {
+  // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.FOrdLessThanEqual %input, %min : f32
+  %mid = spv.Select %0, %min, %input : i1, f32
+  %1 = spv.FOrdLessThanEqual %max, %input : f32
+  %2 = spv.Select %1, %max, %mid : i1, f32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : f32
+}
+
+// -----
 
+// CHECK-LABEL: func @clamp_slessthan
+//  CHECK-SAME: (%[[INPUT:.*]]: si32, %[[MIN:.*]]: si32, %[[MAX:.*]]: si32)
+func @clamp_slessthan(%input: si32, %min: si32, %max: si32) -> si32 {
   // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
   %0 = spv.SLessThan %min, %input : si32
   %mid = spv.Select %0, %input, %min : i1, si32
@@ -57,13 +75,24 @@ func @clamp_slessthan(%input: si32) -> si32 {
 
 // -----
 
-// CHECK: func @clamp_slessthanequal(%[[INPUT:.*]]: si32)
-func @clamp_slessthanequal(%input: si32) -> si32 {
-  // CHECK: %[[MIN:.*]] = spv.Constant
-  %min = spv.Constant 0 : si32
-  // CHECK: %[[MAX:.*]] = spv.Constant
-  %max = spv.Constant 10 : si32
+// CHECK-LABEL: func @clamp_slessthan
+//  CHECK-SAME: (%[[INPUT:.*]]: si32, %[[MIN:.*]]: si32, %[[MAX:.*]]: si32)
+func @clamp_slessthan(%input: si32, %min: si32, %max: si32) -> si32 {
+  // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.SLessThan %input, %min : si32
+  %mid = spv.Select %0, %min, %input : i1, si32
+  %1 = spv.SLessThan %max, %input : si32
+  %2 = spv.Select %1, %max, %mid : i1, si32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : si32
+}
+
+// -----
 
+// CHECK-LABEL: func @clamp_slessthanequal
+//  CHECK-SAME: (%[[INPUT:.*]]: si32, %[[MIN:.*]]: si32, %[[MAX:.*]]: si32)
+func @clamp_slessthanequal(%input: si32, %min: si32, %max: si32) -> si32 {
   // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
   %0 = spv.SLessThanEqual %min, %input : si32
   %mid = spv.Select %0, %input, %min : i1, si32
@@ -76,13 +105,24 @@ func @clamp_slessthanequal(%input: si32) -> si32 {
 
 // -----
 
-// CHECK: func @clamp_ulessthan(%[[INPUT:.*]]: i32)
-func @clamp_ulessthan(%input: i32) -> i32 {
-  // CHECK: %[[MIN:.*]] = spv.Constant
-  %min = spv.Constant 0 : i32
-  // CHECK: %[[MAX:.*]] = spv.Constant
-  %max = spv.Constant 10 : i32
+// CHECK-LABEL: func @clamp_slessthanequal
+//  CHECK-SAME: (%[[INPUT:.*]]: si32, %[[MIN:.*]]: si32, %[[MAX:.*]]: si32)
+func @clamp_slessthanequal(%input: si32, %min: si32, %max: si32) -> si32 {
+  // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.SLessThanEqual %input, %min : si32
+  %mid = spv.Select %0, %min, %input : i1, si32
+  %1 = spv.SLessThanEqual %max, %input : si32
+  %2 = spv.Select %1, %max, %mid : i1, si32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : si32
+}
 
+// -----
+
+// CHECK-LABEL: func @clamp_ulessthan
+//  CHECK-SAME: (%[[INPUT:.*]]: i32, %[[MIN:.*]]: i32, %[[MAX:.*]]: i32)
+func @clamp_ulessthan(%input: i32, %min: i32, %max: i32) -> i32 {
   // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
   %0 = spv.ULessThan %min, %input : i32
   %mid = spv.Select %0, %input, %min : i1, i32
@@ -95,13 +135,24 @@ func @clamp_ulessthan(%input: i32) -> i32 {
 
 // -----
 
-// CHECK: func @clamp_ulessthanequal(%[[INPUT:.*]]: i32)
-func @clamp_ulessthanequal(%input: i32) -> i32 {
-  // CHECK: %[[MIN:.*]] = spv.Constant
-  %min = spv.Constant 0 : i32
-  // CHECK: %[[MAX:.*]] = spv.Constant
-  %max = spv.Constant 10 : i32
+// CHECK-LABEL: func @clamp_ulessthan
+//  CHECK-SAME: (%[[INPUT:.*]]: i32, %[[MIN:.*]]: i32, %[[MAX:.*]]: i32)
+func @clamp_ulessthan(%input: i32, %min: i32, %max: i32) -> i32 {
+  // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.ULessThan %input, %min : i32
+  %mid = spv.Select %0, %min, %input : i1, i32
+  %1 = spv.ULessThan %max, %input : i32
+  %2 = spv.Select %1, %max, %mid : i1, i32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : i32
+}
+
+// -----
 
+// CHECK-LABEL: func @clamp_ulessthanequal
+//  CHECK-SAME: (%[[INPUT:.*]]: i32, %[[MIN:.*]]: i32, %[[MAX:.*]]: i32)
+func @clamp_ulessthanequal(%input: i32, %min: i32, %max: i32) -> i32 {
   // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
   %0 = spv.ULessThanEqual %min, %input : i32
   %mid = spv.Select %0, %input, %min : i1, i32
@@ -111,3 +162,18 @@ func @clamp_ulessthanequal(%input: i32) -> i32 {
   // CHECK-NEXT: spv.ReturnValue [[RES]]
   spv.ReturnValue %2 : i32
 }
+
+// -----
+
+// CHECK-LABEL: func @clamp_ulessthanequal
+//  CHECK-SAME: (%[[INPUT:.*]]: i32, %[[MIN:.*]]: i32, %[[MAX:.*]]: i32)
+func @clamp_ulessthanequal(%input: i32, %min: i32, %max: i32) -> i32 {
+  // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.ULessThanEqual %input, %min : i32
+  %mid = spv.Select %0, %min, %input : i1, i32
+  %1 = spv.ULessThanEqual %max, %input : i32
+  %2 = spv.Select %1, %max, %mid : i1, i32
+
+  // CHECK-NEXT: spv.ReturnValue [[RES]]
+  spv.ReturnValue %2 : i32
+}


        


More information about the Mlir-commits mailing list