[Mlir-commits] [mlir] [mlir][spirv] Add folding for SelectOp (PR #85430)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Mar 18 00:39:20 PDT 2024
================
@@ -797,6 +797,85 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
return Attribute();
}
+//===----------------------------------------------------------------------===//
+// spirv.SelectOp
+//===----------------------------------------------------------------------===//
+
+/// Builds the element-wise result of a constant `spirv.Select`.
+///
+/// Component i of the result comes from `trueAttrs` when component i of
+/// `condAttrs` is true, and from `falseAttrs` otherwise. The three
+/// attributes are walked in lockstep for `condAttrs.getNumElements()`
+/// components.
+///
+/// `AttrElementT` is the attribute kind of the object elements (e.g.
+/// IntegerAttr or FloatAttr). `ElementValueT` is the raw value type read from
+/// the object attributes. The result is a DenseElementsAttr whose type is the
+/// type of `trueAttrs`.
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType>
+static Attribute foldSelections(const ElementsAttr &condAttrs,
+                                const ElementsAttr &trueAttrs,
+                                const ElementsAttr &falseAttrs) {
+  const size_t numElements = condAttrs.getNumElements();
+  SmallVector<ElementValueT, 4> selected;
+  selected.reserve(numElements);
+
+  auto condIt = condAttrs.value_begin<BoolAttr>();
+  auto onTrueIt = trueAttrs.value_begin<ElementValueT>();
+  auto onFalseIt = falseAttrs.value_begin<ElementValueT>();
+  for (size_t idx = 0; idx != numElements; ++idx) {
+    // Object 1 when this condition component is set, Object 2 otherwise.
+    const bool takeTrue = (*condIt).getValue();
+    selected.push_back(takeTrue ? *onTrueIt : *onFalseIt);
+    ++condIt;
+    ++onTrueIt;
+    ++onFalseIt;
+  }
+
+  auto shapedResultType = cast<ShapedType>(trueAttrs.getType());
+  return DenseElementsAttr::get(shapedResultType, selected);
+}
+
+/// Constant-folds `spirv.Select` when the condition and both objects are
+/// constant elements attributes.
+///
+/// `operands` holds, in order, the constant attribute of the condition, of
+/// Object 1 and of Object 2. An entry may be a null attribute, as MLIR fold
+/// adaptors use for operands that are not known constants.
+///
+/// Returns the folded DenseElementsAttr. Returns a null attribute when any
+/// operand is not an ElementsAttr or when types or sizes do not line up.
+/// Integer and float element types are the only ones folded.
+static Attribute foldSelectOp(llvm::ArrayRef<Attribute> operands) {
+  // dyn_cast_if_present tolerates null entries. Plain llvm::dyn_cast asserts
+  // on a null value.
+  auto condAttrs = llvm::dyn_cast_if_present<ElementsAttr>(operands[0]);
+  auto trueAttrs = llvm::dyn_cast_if_present<ElementsAttr>(operands[1]);
+  auto falseAttrs = llvm::dyn_cast_if_present<ElementsAttr>(operands[2]);
+  if (!condAttrs || !trueAttrs || !falseAttrs)
+    return Attribute();
+
+  // According to the SPIR-V spec:
+  //
+  // If Condition is a vector, Result Type must be a vector with the same
+  // number of components as Condition and the result is a mix of Object 1
+  // and Object 2: When a component of Condition is true, the corresponding
+  // component in the result is taken from Object 1, otherwise it is taken
+  // from Object 2.
+  if (trueAttrs.getType() != falseAttrs.getType() ||
+      !condAttrs.getElementType().isInteger(1))
+    return Attribute();
+
+  // foldSelections walks the three attributes in lockstep. The condition must
+  // therefore supply exactly one component per result component, or the
+  // object iterators could run past their end.
+  if (condAttrs.getNumElements() != trueAttrs.getNumElements())
+    return Attribute();
+
+  Type elementType = trueAttrs.getElementType();
+  if (llvm::isa<IntegerType>(elementType))
+    return foldSelections<IntegerAttr>(condAttrs, trueAttrs, falseAttrs);
+  if (llvm::isa<FloatType>(elementType))
+    return foldSelections<FloatAttr>(condAttrs, trueAttrs, falseAttrs);
+
+  return Attribute();
+}
+
+OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
+ // spirv.Select _ x x -> x
+ auto trueVals = getOperand(1);
+ auto falseVals = getOperand(2);
----------------
kuhar wrote:
I think this should have named getters like `getCondition()`, `getTrueValue()`, no?
https://github.com/llvm/llvm-project/blob/65ae09eeb6773b14189fc67051870c8fc4eb9ae3/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td#L787-L789
https://github.com/llvm/llvm-project/pull/85430
More information about the Mlir-commits mailing list