[Mlir-commits] [mlir] [mlir][spirv] Add folding for SelectOp (PR #85430)
Finn Plummer
llvmlistbot at llvm.org
Tue Mar 19 12:18:27 PDT 2024
https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/85430
>From 4998cc1f0c3e0aaa23dde9baf86b04be90978ba9 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 12 Mar 2024 11:07:42 -0700
Subject: [PATCH 1/3] [mlir][spirv] Add folding for SelectOp
Add the missing constant-propagation folder for spirv.Select.
Implement additional folding when both selected values are the same, or
when the condition is a constant scalar or splat vector.
Allows for constant folding in the IndexToSPIRV pass.
Part of work #70704
---
.../mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td | 2 +
.../SPIRV/IR/SPIRVCanonicalization.cpp | 79 +++++++++++++++++++
.../SPIRVToLLVM/misc-ops-to-llvm.mlir | 10 +--
.../SPIRV/Transforms/canonicalize.mlir | 46 +++++++++++
4 files changed, 132 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index e48a56f0625d3f..3ee239d6e1e3ec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -800,6 +800,8 @@ def SPIRV_SelectOp : SPIRV_Op<"Select",
// These ops require dynamic availability specification based on operand and
// result types.
bit autogenAvailability = 0;
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 4c62289a1e9458..ceb929e7e09808 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -797,6 +797,85 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
return Attribute();
}
+//===----------------------------------------------------------------------===//
+// spirv.SelectOp
+//===----------------------------------------------------------------------===//
+
+template <class AttrElementT,
+ class ElementValueT = typename AttrElementT::ValueType>
+static Attribute foldSelections(const ElementsAttr &condAttrs,
+ const ElementsAttr &trueAttrs,
+ const ElementsAttr &falseAttrs) {
+ auto condsIt = condAttrs.value_begin<BoolAttr>();
+ auto trueAttrsIt = trueAttrs.value_begin<ElementValueT>();
+ auto falseAttrsIt = falseAttrs.value_begin<ElementValueT>();
+
+ SmallVector<ElementValueT, 4> elementResults;
+ elementResults.reserve(condAttrs.getNumElements());
+ for (size_t i = 0, e = condAttrs.getNumElements(); i < e;
+ ++i, ++condsIt, ++trueAttrsIt, ++falseAttrsIt) {
+ if ((*condsIt).getValue()) // If Condition then take Object 1
+ elementResults.push_back(*trueAttrsIt);
+ else // Else take Object 2
+ elementResults.push_back(*falseAttrsIt);
+ }
+
+ auto resultType = trueAttrs.getType();
+ return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
+}
+
+static Attribute foldSelectOp(llvm::ArrayRef<Attribute> operands) {
+ auto condAttrs = llvm::dyn_cast<ElementsAttr>(operands[0]);
+ auto trueAttrs = llvm::dyn_cast<ElementsAttr>(operands[1]);
+ auto falseAttrs = llvm::dyn_cast<ElementsAttr>(operands[2]);
+ if (!condAttrs || !trueAttrs || !falseAttrs)
+ return Attribute();
+
+ // According to the SPIR-V spec:
+ //
+ // If Condition is a vector, Result Type must be a vector with the same
+ // number of components as Condition and the result is a mix of Object 1
+ // and Object 2: When a component of Condition is true, the corresponding
+ // component in the result is taken from Object 1, otherwise it is taken
+ // from Object 2.
+ auto elementType = trueAttrs.getElementType();
+ if (trueAttrs.getType() != falseAttrs.getType() ||
+ !condAttrs.getElementType().isInteger(1))
+ return Attribute();
+
+ if (llvm::isa<IntegerType>(elementType)) {
+ return foldSelections<IntegerAttr>(condAttrs, trueAttrs, falseAttrs);
+ } else if (llvm::isa<FloatType>(elementType)) {
+ return foldSelections<FloatAttr>(condAttrs, trueAttrs, falseAttrs);
+ }
+
+ return Attribute();
+}
+
+OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
+ // spirv.Select _ x x -> x
+ auto trueVals = getOperand(1);
+ auto falseVals = getOperand(2);
+ if (trueVals == falseVals)
+ return trueVals;
+
+ auto operands = adaptor.getOperands();
+
+ // spirv.Select true x y -> x
+ // spirv.Select false x y -> y
+ if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
+ return *boolAttr ? trueVals : falseVals;
+
+ // Check that all the operands are constant
+ if (!operands[0] || !operands[1] || !operands[2])
+ return Attribute();
+
+ // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
+ // the scalar case. Hence, we are only required to consider the case of
+ // ElementsAttr in foldSelectOp.
+ return foldSelectOp(operands);
+}
+
//===----------------------------------------------------------------------===//
// spirv.IEqualOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
index 9fe1e532dfc77c..31da59dcdc7260 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
@@ -43,18 +43,18 @@ spirv.func @composite_insert_vector(%arg0: vector<3xf32>, %arg1: f32) "None" {
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @select_scalar
-spirv.func @select_scalar(%arg0: i1, %arg1: vector<3xi32>, %arg2: f32) "None" {
+spirv.func @select_scalar(%arg0: i1, %arg1: vector<3xi32>, %arg2: vector<3xi32>, %arg3: f32, %arg4: f32) "None" {
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, vector<3xi32>
- %0 = spirv.Select %arg0, %arg1, %arg1 : i1, vector<3xi32>
+ %0 = spirv.Select %arg0, %arg1, %arg2 : i1, vector<3xi32>
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, f32
- %1 = spirv.Select %arg0, %arg2, %arg2 : i1, f32
+ %1 = spirv.Select %arg0, %arg3, %arg4 : i1, f32
spirv.Return
}
// CHECK-LABEL: @select_vector
-spirv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>) "None" {
+spirv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>, %arg2: vector<2xi32>) "None" {
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi1>, vector<2xi32>
- %0 = spirv.Select %arg0, %arg1, %arg1 : vector<2xi1>, vector<2xi32>
+ %0 = spirv.Select %arg0, %arg1, %arg2 : vector<2xi1>, vector<2xi32>
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 1cb69891a70ed6..de21d114e9fc4f 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1346,6 +1346,52 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
// -----
+//===----------------------------------------------------------------------===//
+// spirv.Select
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @convert_select_scalar
+// CHECK-SAME: %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32
+func.func @convert_select_scalar(%arg1: i32, %arg2: i32) -> (i32, i32) {
+ %true = spirv.Constant true
+ %false = spirv.Constant false
+ %0 = spirv.Select %true, %arg1, %arg2 : i1, i32
+ %1 = spirv.Select %false, %arg1, %arg2 : i1, i32
+
+ // CHECK: return %[[ARG1]], %[[ARG2]]
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @convert_select_vector
+// CHECK-SAME: %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>
+func.func @convert_select_vector(%arg1: vector<3xi32>, %arg2: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
+ %true = spirv.Constant dense<true> : vector<3xi1>
+ %false = spirv.Constant dense<false> : vector<3xi1>
+ %0 = spirv.Select %true, %arg1, %arg2 : vector<3xi1>, vector<3xi32>
+ %1 = spirv.Select %false, %arg1, %arg2 : vector<3xi1>, vector<3xi32>
+
+ // CHECK: return %[[ARG1]], %[[ARG2]]
+ return %0, %1: vector<3xi32>, vector<3xi32>
+}
+
+// CHECK-LABEL: @convert_select_vector_extra
+// CHECK-SAME: %[[CONDITIONS:.+]]: vector<2xi1>, %[[ARG1:.+]]: vector<2xi32>
+func.func @convert_select_vector_extra(%conditions: vector<2xi1>, %arg1: vector<2xi32>) -> (vector<2xi32>, vector<2xi32>) {
+ %true_false = spirv.Constant dense<[true, false]> : vector<2xi1>
+ %cvec_1 = spirv.Constant dense<[42, -132]> : vector<2xi32>
+ %cvec_2 = spirv.Constant dense<[0, 42]> : vector<2xi32>
+
+ // CHECK: %[[RES:.+]] = spirv.Constant dense<42>
+ %0 = spirv.Select %true_false, %cvec_1, %cvec_2: vector<2xi1>, vector<2xi32>
+
+ %1 = spirv.Select %conditions, %arg1, %arg1 : vector<2xi1>, vector<2xi32>
+
+ // CHECK: return %[[RES]], %[[ARG1]]
+ return %0, %1: vector<2xi32>, vector<2xi32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.IEqual
//===----------------------------------------------------------------------===//
>From 7785e59a3d1d1475e0545b37674ff5f652049270 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 19 Mar 2024 11:02:36 -0700
Subject: [PATCH 2/3] Address review comments
- remove use of auto when type is unclear
- simplify the folding, as we don't need to separate the integer and float cases
- remove unneeded llvm namespacing
- use operand specific getters for readability
---
.../SPIRV/IR/SPIRVCanonicalization.cpp | 54 +++++--------------
1 file changed, 14 insertions(+), 40 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index ceb929e7e09808..ba99352fadcbfd 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -801,16 +801,18 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
// spirv.SelectOp
//===----------------------------------------------------------------------===//
-template <class AttrElementT,
- class ElementValueT = typename AttrElementT::ValueType>
-static Attribute foldSelections(const ElementsAttr &condAttrs,
- const ElementsAttr &trueAttrs,
- const ElementsAttr &falseAttrs) {
+static Attribute foldSelectOp(ArrayRef<Attribute> operands) {
+ auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
+ auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
+ auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
+ if (!condAttrs || !trueAttrs || !falseAttrs)
+ return Attribute();
+
auto condsIt = condAttrs.value_begin<BoolAttr>();
- auto trueAttrsIt = trueAttrs.value_begin<ElementValueT>();
- auto falseAttrsIt = falseAttrs.value_begin<ElementValueT>();
+ auto trueAttrsIt = trueAttrs.value_begin<Attribute>();
+ auto falseAttrsIt = falseAttrs.value_begin<Attribute>();
- SmallVector<ElementValueT, 4> elementResults;
+ SmallVector<Attribute, 4> elementResults;
elementResults.reserve(condAttrs.getNumElements());
for (size_t i = 0, e = condAttrs.getNumElements(); i < e;
++i, ++condsIt, ++trueAttrsIt, ++falseAttrsIt) {
@@ -824,42 +826,14 @@ static Attribute foldSelections(const ElementsAttr &condAttrs,
return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
}
-static Attribute foldSelectOp(llvm::ArrayRef<Attribute> operands) {
- auto condAttrs = llvm::dyn_cast<ElementsAttr>(operands[0]);
- auto trueAttrs = llvm::dyn_cast<ElementsAttr>(operands[1]);
- auto falseAttrs = llvm::dyn_cast<ElementsAttr>(operands[2]);
- if (!condAttrs || !trueAttrs || !falseAttrs)
- return Attribute();
-
- // According to the SPIR-V spec:
- //
- // If Condition is a vector, Result Type must be a vector with the same
- // number of components as Condition and the result is a mix of Object 1
- // and Object 2: When a component of Condition is true, the corresponding
- // component in the result is taken from Object 1, otherwise it is taken
- // from Object 2.
- auto elementType = trueAttrs.getElementType();
- if (trueAttrs.getType() != falseAttrs.getType() ||
- !condAttrs.getElementType().isInteger(1))
- return Attribute();
-
- if (llvm::isa<IntegerType>(elementType)) {
- return foldSelections<IntegerAttr>(condAttrs, trueAttrs, falseAttrs);
- } else if (llvm::isa<FloatType>(elementType)) {
- return foldSelections<FloatAttr>(condAttrs, trueAttrs, falseAttrs);
- }
-
- return Attribute();
-}
-
OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
// spirv.Select _ x x -> x
- auto trueVals = getOperand(1);
- auto falseVals = getOperand(2);
+ Value trueVals = getTrueValue();
+ Value falseVals = getFalseValue();
if (trueVals == falseVals)
return trueVals;
- auto operands = adaptor.getOperands();
+ ArrayRef<Attribute> operands = adaptor.getOperands();
// spirv.Select true x y -> x
// spirv.Select false x y -> y
@@ -872,7 +846,7 @@ OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
// Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
// the scalar case. Hence, we are only required to consider the case of
- // ElementsAttr in foldSelectOp.
+ // DenseElementsAttr in foldSelectOp.
return foldSelectOp(operands);
}
>From 9539ff518e4b60d0d67488aa88e43f0bdb10fae1 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 19 Mar 2024 12:15:32 -0700
Subject: [PATCH 3/3] Address review comments
- use abstractions to improve loop readability
- inline due to simplicity
---
.../SPIRV/IR/SPIRVCanonicalization.cpp | 42 +++++++------------
1 file changed, 16 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index ba99352fadcbfd..ff4bace9a4d882 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -801,31 +801,6 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
// spirv.SelectOp
//===----------------------------------------------------------------------===//
-static Attribute foldSelectOp(ArrayRef<Attribute> operands) {
- auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
- auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
- auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
- if (!condAttrs || !trueAttrs || !falseAttrs)
- return Attribute();
-
- auto condsIt = condAttrs.value_begin<BoolAttr>();
- auto trueAttrsIt = trueAttrs.value_begin<Attribute>();
- auto falseAttrsIt = falseAttrs.value_begin<Attribute>();
-
- SmallVector<Attribute, 4> elementResults;
- elementResults.reserve(condAttrs.getNumElements());
- for (size_t i = 0, e = condAttrs.getNumElements(); i < e;
- ++i, ++condsIt, ++trueAttrsIt, ++falseAttrsIt) {
- if ((*condsIt).getValue()) // If Condition then take Object 1
- elementResults.push_back(*trueAttrsIt);
- else // Else take Object 2
- elementResults.push_back(*falseAttrsIt);
- }
-
- auto resultType = trueAttrs.getType();
- return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
-}
-
OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
// spirv.Select _ x x -> x
Value trueVals = getTrueValue();
@@ -847,7 +822,22 @@ OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
// Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
// the scalar case. Hence, we are only required to consider the case of
// DenseElementsAttr in foldSelectOp.
- return foldSelectOp(operands);
+ auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
+ auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
+ auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
+ if (!condAttrs || !trueAttrs || !falseAttrs)
+ return Attribute();
+
+ auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
+ auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
+ falseAttrs.getValues<Attribute>());
+ for (auto [result, cond, falseRes] : iters) {
+ if (!cond.getValue())
+ result = falseRes;
+ }
+
+ auto resultType = trueAttrs.getType();
+ return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list