[Mlir-commits] [mlir] [mlir][spirv] Add folding for SelectOp (PR #85430)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 15 13:42:33 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir
Author: Finn Plummer (inbelic)
<details>
<summary>Changes</summary>
Add the missing constant-propagation folder for spirv.Select.
Implement additional folding when both selected values are identical, or when the condition is a constant scalar or splat vector.
Allows for constant folding in the IndexToSPIRV pass.
Part of work #<!-- -->70704
---
Full diff: https://github.com/llvm/llvm-project/pull/85430.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td (+2)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp (+79)
- (modified) mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir (+5-5)
- (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+46)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index e48a56f0625d3f..3ee239d6e1e3ec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -800,6 +800,8 @@ def SPIRV_SelectOp : SPIRV_Op<"Select",
// These ops require dynamic availability specification based on operand and
// result types.
bit autogenAvailability = 0;
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 4c62289a1e9458..ceb929e7e09808 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -797,6 +797,85 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
return Attribute();
}
+//===----------------------------------------------------------------------===//
+// spirv.SelectOp
+//===----------------------------------------------------------------------===//
+
+/// Constant-folds a spirv.Select whose condition is a non-splat constant
+/// vector and whose true/false values are constant vectors.
+///
+/// According to the SPIR-V spec: if Condition is a vector, Result Type must be
+/// a vector with the same number of components as Condition and the result is
+/// a mix of Object 1 and Object 2. When a component of Condition is true, the
+/// corresponding component in the result is taken from Object 1, otherwise it
+/// is taken from Object 2.
+///
+/// `operands` holds the constant attributes of (condition, true value, false
+/// value); any of them may be null. Only DenseElementsAttr operands are
+/// folded: a generic ElementsAttr (e.g. resource-backed data) may not support
+/// element iteration, and iterating it unconditionally may assert.
+/// Returns a null Attribute when the operands cannot be folded.
+static Attribute foldSelectOp(llvm::ArrayRef<Attribute> operands) {
+  auto condAttrs = llvm::dyn_cast_if_present<DenseElementsAttr>(operands[0]);
+  auto trueAttrs = llvm::dyn_cast_if_present<DenseElementsAttr>(operands[1]);
+  auto falseAttrs = llvm::dyn_cast_if_present<DenseElementsAttr>(operands[2]);
+  if (!condAttrs || !trueAttrs || !falseAttrs)
+    return Attribute();
+
+  // Defensive shape checks; the op verifier should already guarantee these,
+  // and zip_equal below asserts if the element counts ever disagree.
+  if (trueAttrs.getType() != falseAttrs.getType() ||
+      !condAttrs.getElementType().isInteger(1) ||
+      condAttrs.getNumElements() != trueAttrs.getNumElements())
+    return Attribute();
+
+  // Start from Object 1 and replace every component whose condition is false
+  // with the matching component of Object 2. Working on generic element
+  // attributes handles integer and floating-point vectors uniformly.
+  auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
+  for (auto [result, cond, falseElem] :
+       llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
+                       falseAttrs.getValues<Attribute>())) {
+    if (!cond.getValue())
+      result = falseElem;
+  }
+
+  return DenseElementsAttr::get(trueAttrs.getType(), elementResults);
+}
+
+/// Folds spirv.Select:
+///   spirv.Select _ x x         -> x
+///   spirv.Select true  x y     -> x   (scalar or splat-vector condition)
+///   spirv.Select false x y     -> y   (scalar or splat-vector condition)
+///   spirv.Select C1 C2 C3      -> element-wise constant (non-splat vectors)
+OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
+  // spirv.Select _ x x -> x
+  Value trueValue = getTrueValue();
+  Value falseValue = getFalseValue();
+  if (trueValue == falseValue)
+    return trueValue;
+
+  ArrayRef<Attribute> operands = adaptor.getOperands();
+
+  // spirv.Select true x y -> x
+  // spirv.Select false x y -> y
+  // getScalarOrSplatBoolAttr covers every scalar condition as well as splat
+  // vector conditions, so only non-splat vector conditions reach the code
+  // below.
+  if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
+    return *boolAttr ? trueValue : falseValue;
+
+  // Element-wise folding needs all three operands to be constants.
+  if (!operands[0] || !operands[1] || !operands[2])
+    return Attribute();
+
+  return foldSelectOp(operands);
+}
+
//===----------------------------------------------------------------------===//
// spirv.IEqualOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
index 9fe1e532dfc77c..31da59dcdc7260 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
@@ -43,18 +43,18 @@ spirv.func @composite_insert_vector(%arg0: vector<3xf32>, %arg1: f32) "None" {
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @select_scalar
+// Distinct true/false operands are used on purpose: spirv.Select now has a
+// folder that rewrites `select %c, %x, %x` to `%x`, which could remove the op
+// before it is lowered to llvm.select.
-spirv.func @select_scalar(%arg0: i1, %arg1: vector<3xi32>, %arg2: f32) "None" {
+spirv.func @select_scalar(%arg0: i1, %arg1: vector<3xi32>, %arg2: vector<3xi32>, %arg3: f32, %arg4: f32) "None" {
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, vector<3xi32>
- %0 = spirv.Select %arg0, %arg1, %arg1 : i1, vector<3xi32>
+ %0 = spirv.Select %arg0, %arg1, %arg2 : i1, vector<3xi32>
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, f32
- %1 = spirv.Select %arg0, %arg2, %arg2 : i1, f32
+ %1 = spirv.Select %arg0, %arg3, %arg4 : i1, f32
spirv.Return
}
// CHECK-LABEL: @select_vector
+// As above, distinct operands keep the spirv.Select folder from removing the
+// op before lowering.
-spirv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>) "None" {
+spirv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>, %arg2: vector<2xi32>) "None" {
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi1>, vector<2xi32>
- %0 = spirv.Select %arg0, %arg1, %arg1 : vector<2xi1>, vector<2xi32>
+ %0 = spirv.Select %arg0, %arg1, %arg2 : vector<2xi1>, vector<2xi32>
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 1cb69891a70ed6..de21d114e9fc4f 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1346,6 +1346,52 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
// -----
+//===----------------------------------------------------------------------===//
+// spirv.Select
+//===----------------------------------------------------------------------===//
+
+// A constant scalar condition folds the select to one of its operands:
+// true picks the first value operand, false picks the second.
+// CHECK-LABEL: @convert_select_scalar
+// CHECK-SAME: %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32
+func.func @convert_select_scalar(%arg1: i32, %arg2: i32) -> (i32, i32) {
+ %true = spirv.Constant true
+ %false = spirv.Constant false
+ %0 = spirv.Select %true, %arg1, %arg2 : i1, i32
+ %1 = spirv.Select %false, %arg1, %arg2 : i1, i32
+
+ // CHECK: return %[[ARG1]], %[[ARG2]]
+ return %0, %1 : i32, i32
+}
+
+// A splat constant vector condition behaves like a scalar one and folds the
+// select to a whole operand, even though the operands are not constants.
+// CHECK-LABEL: @convert_select_vector
+// CHECK-SAME: %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>
+func.func @convert_select_vector(%arg1: vector<3xi32>, %arg2: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
+ %true = spirv.Constant dense<true> : vector<3xi1>
+ %false = spirv.Constant dense<false> : vector<3xi1>
+ %0 = spirv.Select %true, %arg1, %arg2 : vector<3xi1>, vector<3xi32>
+ %1 = spirv.Select %false, %arg1, %arg2 : vector<3xi1>, vector<3xi32>
+
+ // CHECK: return %[[ARG1]], %[[ARG2]]
+ return %0, %1: vector<3xi32>, vector<3xi32>
+}
+
+// A non-splat constant condition with constant operands folds element-wise
+// ([42, -132] / [0, 42] under [true, false] gives the splat [42, 42]);
+// identical value operands fold to that operand for any condition.
+// CHECK-LABEL: @convert_select_vector_extra
+// CHECK-SAME: %[[CONDITIONS:.+]]: vector<2xi1>, %[[ARG1:.+]]: vector<2xi32>
+func.func @convert_select_vector_extra(%conditions: vector<2xi1>, %arg1: vector<2xi32>) -> (vector<2xi32>, vector<2xi32>) {
+  %true_false = spirv.Constant dense<[true, false]> : vector<2xi1>
+  %cvec_1 = spirv.Constant dense<[42, -132]> : vector<2xi32>
+  %cvec_2 = spirv.Constant dense<[0, 42]> : vector<2xi32>
+
+  // CHECK: %[[RES:.+]] = spirv.Constant dense<42>
+  %0 = spirv.Select %true_false, %cvec_1, %cvec_2: vector<2xi1>, vector<2xi32>
+
+  %1 = spirv.Select %conditions, %arg1, %arg1 : vector<2xi1>, vector<2xi32>
+
+  // CHECK: return %[[RES]], %[[ARG1]]
+  return %0, %1: vector<2xi32>, vector<2xi32>
+}
+
+// Element-wise folding also covers floating-point vectors: under
+// [true, false], [1.5, -2.0] / [0.0, 4.0] gives [1.5, 4.0].
+// CHECK-LABEL: @convert_select_vector_float
+func.func @convert_select_vector_float() -> vector<2xf32> {
+  %true_false = spirv.Constant dense<[true, false]> : vector<2xi1>
+  %cvec_1 = spirv.Constant dense<[1.5, -2.0]> : vector<2xf32>
+  %cvec_2 = spirv.Constant dense<[0.0, 4.0]> : vector<2xf32>
+
+  // CHECK: %[[RES:.+]] = spirv.Constant dense<[1.500000e+00, 4.000000e+00]> : vector<2xf32>
+  %0 = spirv.Select %true_false, %cvec_1, %cvec_2 : vector<2xi1>, vector<2xf32>
+
+  // CHECK: return %[[RES]]
+  return %0 : vector<2xf32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.IEqual
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/85430
More information about the Mlir-commits
mailing list