[Mlir-commits] [mlir] [mlir][spirv] Add folding for SelectOp (PR #85430)

Finn Plummer llvmlistbot at llvm.org
Fri Mar 15 09:59:04 PDT 2024


https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/85430

Add missing constant propagation folder for spirv.Select

Implement additional folding when both selections are equivalent or the condition is a constant Scalar/SplatVector.

Allows for constant folding in the IndexToSPIRV pass.

Part of work #70704

>From 4998cc1f0c3e0aaa23dde9baf86b04be90978ba9 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 12 Mar 2024 11:07:42 -0700
Subject: [PATCH] [mlir][spirv] Add folding for SelectOp

Add missing constant propagation folder for spirv.Select

Implement additional folding when both selections are equivalent or the
condition is a constant Scalar/SplatVector.

Allows for constant folding in the IndexToSPIRV pass.

Part of work #70704
---
 .../mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td  |  2 +
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 79 +++++++++++++++++++
 .../SPIRVToLLVM/misc-ops-to-llvm.mlir         | 10 +--
 .../SPIRV/Transforms/canonicalize.mlir        | 46 +++++++++++
 4 files changed, 132 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index e48a56f0625d3f..3ee239d6e1e3ec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -800,6 +800,8 @@ def SPIRV_SelectOp : SPIRV_Op<"Select",
   // These ops require dynamic availability specification based on operand and
   // result types.
   bit autogenAvailability = 0;
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 4c62289a1e9458..ceb929e7e09808 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -797,6 +797,85 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SelectOp
+//===----------------------------------------------------------------------===//
+
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType>
+static Attribute foldSelections(const ElementsAttr &condAttrs,
+                                const ElementsAttr &trueAttrs,
+                                const ElementsAttr &falseAttrs) {
+  auto condsIt = condAttrs.value_begin<BoolAttr>();
+  auto trueAttrsIt = trueAttrs.value_begin<ElementValueT>();
+  auto falseAttrsIt = falseAttrs.value_begin<ElementValueT>();
+
+  SmallVector<ElementValueT, 4> elementResults; // One entry per element
+  elementResults.reserve(condAttrs.getNumElements());
+  for (size_t i = 0, e = condAttrs.getNumElements(); i < e;
+       ++i, ++condsIt, ++trueAttrsIt, ++falseAttrsIt) {
+    if ((*condsIt).getValue()) // True: take the Object 1 component
+      elementResults.push_back(*trueAttrsIt);
+    else // False: take the Object 2 component
+      elementResults.push_back(*falseAttrsIt);
+  }
+
+  auto resultType = trueAttrs.getType(); // Result reuses the trueAttrs type
+  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
+}
+
+static Attribute foldSelectOp(llvm::ArrayRef<Attribute> operands) {
+  auto condAttrs = llvm::dyn_cast<ElementsAttr>(operands[0]);
+  auto trueAttrs = llvm::dyn_cast<ElementsAttr>(operands[1]);
+  auto falseAttrs = llvm::dyn_cast<ElementsAttr>(operands[2]);
+  if (!condAttrs || !trueAttrs || !falseAttrs)
+    return Attribute(); // Only ElementsAttr operands are folded here
+
+  // According to the SPIR-V spec:
+  //
+  // If Condition is a vector, Result Type must be a vector with the same
+  // number of components as Condition and the result is a mix of Object 1
+  // and Object 2: When a component of Condition is true, the corresponding
+  // component in the result is taken from Object 1, otherwise it is taken
+  // from Object 2.
+  auto elementType = trueAttrs.getElementType();
+  if (trueAttrs.getType() != falseAttrs.getType() ||
+      !condAttrs.getElementType().isInteger(1))
+    return Attribute(); // Mismatched object types or non-i1 condition
+
+  if (llvm::isa<IntegerType>(elementType)) {
+    return foldSelections<IntegerAttr>(condAttrs, trueAttrs, falseAttrs);
+  } else if (llvm::isa<FloatType>(elementType)) {
+    return foldSelections<FloatAttr>(condAttrs, trueAttrs, falseAttrs);
+  }
+
+  return Attribute(); // Element type is neither integer nor float
+}
+
+OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
+  // spirv.Select _ x x -> x
+  auto trueVals = getOperand(1);
+  auto falseVals = getOperand(2);
+  if (trueVals == falseVals)
+    return trueVals; // Condition is irrelevant when both values match
+
+  auto operands = adaptor.getOperands();
+
+  // spirv.Select true  x y -> x
+  // spirv.Select false x y -> y
+  if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
+    return *boolAttr ? trueVals : falseVals;
+
+  // Element-wise folding needs all three operands to be constant
+  if (!operands[0] || !operands[1] || !operands[2])
+    return Attribute();
+
+  // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
+  // the scalar case. Hence, we are only required to consider the case of
+  // ElementsAttr in foldSelectOp.
+  return foldSelectOp(operands);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.IEqualOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
index 9fe1e532dfc77c..31da59dcdc7260 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
@@ -43,18 +43,18 @@ spirv.func @composite_insert_vector(%arg0: vector<3xf32>, %arg1: f32) "None" {
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @select_scalar
-spirv.func @select_scalar(%arg0: i1, %arg1: vector<3xi32>, %arg2: f32) "None" {
+spirv.func @select_scalar(%arg0: i1, %arg1: vector<3xi32>, %arg2: vector<3xi32>, %arg3: f32, %arg4: f32) "None" {
   // CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, vector<3xi32>
-  %0 = spirv.Select %arg0, %arg1, %arg1 : i1, vector<3xi32>
+  %0 = spirv.Select %arg0, %arg1, %arg2 : i1, vector<3xi32>
   // CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, f32
-  %1 = spirv.Select %arg0, %arg2, %arg2 : i1, f32
+  %1 = spirv.Select %arg0, %arg3, %arg4 : i1, f32
   spirv.Return
 }
 
 // CHECK-LABEL: @select_vector
-spirv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>) "None" {
+spirv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>, %arg2: vector<2xi32>) "None" {
   // CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi1>, vector<2xi32>
-  %0 = spirv.Select %arg0, %arg1, %arg1 : vector<2xi1>, vector<2xi32>
+  %0 = spirv.Select %arg0, %arg1, %arg2 : vector<2xi1>, vector<2xi32>
   spirv.Return
 }
 
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 1cb69891a70ed6..de21d114e9fc4f 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1346,6 +1346,52 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.Select
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @convert_select_scalar
+// CHECK-SAME: %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32
+func.func @convert_select_scalar(%arg1: i32, %arg2: i32) -> (i32, i32) {
+  %true = spirv.Constant true
+  %false = spirv.Constant false
+  %0 = spirv.Select %true, %arg1, %arg2 : i1, i32
+  %1 = spirv.Select %false, %arg1, %arg2 : i1, i32
+
+  // CHECK: return %[[ARG1]], %[[ARG2]]
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @convert_select_vector
+// CHECK-SAME: %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>
+func.func @convert_select_vector(%arg1: vector<3xi32>, %arg2: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
+  %true = spirv.Constant dense<true> : vector<3xi1>
+  %false = spirv.Constant dense<false> : vector<3xi1>
+  %0 = spirv.Select %true, %arg1, %arg2 : vector<3xi1>, vector<3xi32>
+  %1 = spirv.Select %false, %arg1, %arg2 : vector<3xi1>, vector<3xi32>
+
+  // CHECK: return %[[ARG1]], %[[ARG2]]
+  return %0, %1: vector<3xi32>, vector<3xi32>
+}
+
+// CHECK-LABEL: @convert_select_vector_extra
+// CHECK-SAME: %[[CONDITIONS:.+]]: vector<2xi1>, %[[ARG1:.+]]: vector<2xi32>
+func.func @convert_select_vector_extra(%conditions: vector<2xi1>, %arg1: vector<2xi32>) -> (vector<2xi32>, vector<2xi32>) {
+  %true_false = spirv.Constant dense<[true, false]> : vector<2xi1>
+  %cvec_1 = spirv.Constant dense<[42, -132]> : vector<2xi32>
+  %cvec_2 = spirv.Constant dense<[0, 42]> : vector<2xi32>
+
+  // CHECK: %[[RES:.+]] = spirv.Constant dense<42>
+  %0 = spirv.Select %true_false, %cvec_1, %cvec_2: vector<2xi1>, vector<2xi32>
+
+  %1 = spirv.Select %conditions, %arg1, %arg1 : vector<2xi1>, vector<2xi32>
+
+  // CHECK: return %[[RES]], %[[ARG1]]
+  return %0, %1: vector<2xi32>, vector<2xi32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IEqual
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list