[Mlir-commits] [mlir] cc3b6f9 - [mlir][spirv] Add folding for [S|U]GreaterThan[Equal] (#85434)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 25 19:49:16 PDT 2024


Author: Finn Plummer
Date: 2024-03-25T19:49:12-07:00
New Revision: cc3b6f94952e1184df44c2b039d21665116151ed

URL: https://github.com/llvm/llvm-project/commit/cc3b6f94952e1184df44c2b039d21665116151ed
DIFF: https://github.com/llvm/llvm-project/commit/cc3b6f94952e1184df44c2b039d21665116151ed.diff

LOG: [mlir][spirv] Add folding for [S|U]GreaterThan[Equal] (#85434)

Add missing constant propagation folder for [S|U]GreaterThan[Equal].

For all four ops, implement additional folding when both operands are the same value.

Allows for constant folding in the IndexToSPIRV pass.

Part of work #70704

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 14d639bc26f26a..61c5a7a6394f8a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -659,6 +659,8 @@ def SPIRV_SGreaterThanOp : SPIRV_LogicalBinaryOp<"SGreaterThan",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -688,6 +690,8 @@ def SPIRV_SGreaterThanEqualOp : SPIRV_LogicalBinaryOp<"SGreaterThanEqual",
     %5 = spirv.SGreaterThanEqual %2, %3 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -834,6 +838,8 @@ def SPIRV_UGreaterThanOp : SPIRV_LogicalBinaryOp<"UGreaterThan",
     %5 = spirv.UGreaterThan %2, %3 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -863,6 +869,8 @@ def SPIRV_UGreaterThanEqualOp : SPIRV_LogicalBinaryOp<"UGreaterThanEqual",
     %5 = spirv.UGreaterThanEqual %2, %3 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 3f1a7826f8b63e..5f47cff71cba00 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -880,6 +880,90 @@ OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
       });
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SGreaterThan
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
+  // x == x -> false
+  if (getOperand1() == getOperand2()) {
+    auto falseAttr = BoolAttr::get(getContext(), false);
+    if (isa<IntegerType>(getType()))
+      return falseAttr;
+    if (auto vecTy = dyn_cast<VectorType>(getType()))
+      return SplatElementsAttr::get(vecTy, falseAttr);
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
+        return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
+      });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SGreaterThanEqual
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SGreaterThanEqualOp::fold(
+    spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
+  // x == x -> true
+  if (getOperand1() == getOperand2()) {
+    auto trueAttr = BoolAttr::get(getContext(), true);
+    if (isa<IntegerType>(getType()))
+      return trueAttr;
+    if (auto vecTy = dyn_cast<VectorType>(getType()))
+      return SplatElementsAttr::get(vecTy, trueAttr);
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
+        return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
+      });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UGreaterThan
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
+  // x == x -> false
+  if (getOperand1() == getOperand2()) {
+    auto falseAttr = BoolAttr::get(getContext(), false);
+    if (isa<IntegerType>(getType()))
+      return falseAttr;
+    if (auto vecTy = dyn_cast<VectorType>(getType()))
+      return SplatElementsAttr::get(vecTy, falseAttr);
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
+        return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
+      });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UGreaterThanEqual
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UGreaterThanEqualOp::fold(
+    spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
+  // x == x -> true
+  if (getOperand1() == getOperand2()) {
+    auto trueAttr = BoolAttr::get(getContext(), true);
+    if (isa<IntegerType>(getType()))
+      return trueAttr;
+    if (auto vecTy = dyn_cast<VectorType>(getType()))
+      return SplatElementsAttr::get(vecTy, trueAttr);
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
+        return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.SLessThan
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index ed0bd070c43c2f..977d31a6bfe542 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1478,6 +1478,182 @@ func.func @const_fold_vector_inotequal() -> vector<3xi1> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SGreaterThan
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sgt_same
+func.func @sgt_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  // CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
+  %0 = spirv.SGreaterThan %arg0, %arg0 : i32
+  %1 = spirv.SGreaterThan %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[CFALSE]], %[[CVFALSE]]
+  return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_sgt
+func.func @const_fold_scalar_sgt() -> (i1, i1) {
+  %c4 = spirv.Constant 4 : i32
+  %c5 = spirv.Constant 5 : i32
+  %c6 = spirv.Constant 6 : i32
+
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  %0 = spirv.SGreaterThan %c5, %c6 : i32
+  %1 = spirv.SGreaterThan %c5, %c4 : i32
+
+  // CHECK: return %[[CFALSE]], %[[CTRUE]]
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_sgt
+func.func @const_fold_vector_sgt() -> vector<3xi1> {
+  %cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
+  %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[false, false, true]>
+  %0 = spirv.SGreaterThan %cv0, %cv1 : vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SGreaterThanEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sge_same
+func.func @sge_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
+  %0 = spirv.SGreaterThanEqual %arg0, %arg0 : i32
+  %1 = spirv.SGreaterThanEqual %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[CTRUE]], %[[CVTRUE]]
+  return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_sge
+func.func @const_fold_scalar_sge() -> (i1, i1) {
+  %c4 = spirv.Constant 4 : i32
+  %c5 = spirv.Constant 5 : i32
+  %c6 = spirv.Constant 6 : i32
+
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  %0 = spirv.SGreaterThanEqual %c5, %c6 : i32
+  %1 = spirv.SGreaterThanEqual %c5, %c4 : i32
+
+  // CHECK: return %[[CFALSE]], %[[CTRUE]]
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_sge
+func.func @const_fold_vector_sge() -> vector<3xi1> {
+  %cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
+  %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[true, false, true]>
+  %0 = spirv.SGreaterThanEqual %cv0, %cv1 : vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UGreaterThan
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @ugt_same
+func.func @ugt_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  // CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
+  %0 = spirv.UGreaterThan %arg0, %arg0 : i32
+  %1 = spirv.UGreaterThan %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[CFALSE]], %[[CVFALSE]]
+  return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_ugt
+func.func @const_fold_scalar_ugt() -> (i1, i1) {
+  %c4 = spirv.Constant 4 : i32
+  %c5 = spirv.Constant 5 : i32
+  %cn6 = spirv.Constant -6 : i32
+
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  %0 = spirv.UGreaterThan %c5, %cn6 : i32
+  %1 = spirv.UGreaterThan %c5, %c4 : i32
+
+  // CHECK: return %[[CFALSE]], %[[CTRUE]]
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_ugt
+func.func @const_fold_vector_ugt() -> vector<3xi1> {
+  %cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
+  %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[false, false, true]>
+  %0 = spirv.UGreaterThan %cv0, %cv1 : vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UGreaterThanEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @uge_same
+func.func @uge_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
+  %0 = spirv.UGreaterThanEqual %arg0, %arg0 : i32
+  %1 = spirv.UGreaterThanEqual %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[CTRUE]], %[[CVTRUE]]
+  return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_uge
+func.func @const_fold_scalar_uge() -> (i1, i1) {
+  %c4 = spirv.Constant 4 : i32
+  %c5 = spirv.Constant 5 : i32
+  %cn6 = spirv.Constant -6 : i32
+
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  %0 = spirv.UGreaterThanEqual %c5, %cn6 : i32
+  %1 = spirv.UGreaterThanEqual %c5, %c4 : i32
+
+  // CHECK: return %[[CFALSE]], %[[CTRUE]]
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_uge
+func.func @const_fold_vector_uge() -> vector<3xi1> {
+  %cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
+  %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[true, false, true]>
+  %0 = spirv.UGreaterThanEqual %cv0, %cv1 : vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi1>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.SLessThan
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list