[Mlir-commits] [mlir] [mlir][spirv] Add folding for SelectOp (PR #85430)

Finn Plummer llvmlistbot at llvm.org
Tue Mar 19 11:11:03 PDT 2024


================
@@ -797,6 +797,85 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SelectOp
+//===----------------------------------------------------------------------===//
+
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType>
+static Attribute foldSelections(const ElementsAttr &condAttrs,
+                                const ElementsAttr &trueAttrs,
+                                const ElementsAttr &falseAttrs) {
+  auto condsIt = condAttrs.value_begin<BoolAttr>();
+  auto trueAttrsIt = trueAttrs.value_begin<ElementValueT>();
+  auto falseAttrsIt = falseAttrs.value_begin<ElementValueT>();
+  // Element-wise select; assumes equal element counts -- TODO confirm.
+  SmallVector<ElementValueT, 4> elementResults;
+  elementResults.reserve(condAttrs.getNumElements()); // one per condition
+  for (size_t i = 0, e = condAttrs.getNumElements(); i < e;
+       ++i, ++condsIt, ++trueAttrsIt, ++falseAttrsIt) { // lockstep walk
+    if ((*condsIt).getValue()) // If Condition then take Object 1
+      elementResults.push_back(*trueAttrsIt);
+    else // Else take Object 2
+      elementResults.push_back(*falseAttrsIt);
+  }
+  // Build the folded constant with the type of the true-operand attribute.
+  auto resultType = trueAttrs.getType();
+  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
+}
+
+static Attribute foldSelectOp(llvm::ArrayRef<Attribute> operands) {
+  auto condAttrs = llvm::dyn_cast<ElementsAttr>(operands[0]);
+  auto trueAttrs = llvm::dyn_cast<ElementsAttr>(operands[1]);
+  auto falseAttrs = llvm::dyn_cast<ElementsAttr>(operands[2]);
+  if (!condAttrs || !trueAttrs || !falseAttrs)
+    return Attribute();
+  // NOTE(review): dyn_cast above presumes non-null operands; confirm caller.
+  // Fold a vector select element-wise. According to the SPIR-V spec:
+  //
+  // If Condition is a vector, Result Type must be a vector with the same
+  // number of components as Condition and the result is a mix of Object 1
+  // and Object 2: When a component of Condition is true, the corresponding
+  // component in the result is taken from Object 1, otherwise it is taken
+  // from Object 2.
+  auto elementType = trueAttrs.getElementType();
+  if (trueAttrs.getType() != falseAttrs.getType() ||
+      !condAttrs.getElementType().isInteger(1))
+    return Attribute(); // differing value types or a non-i1 condition
+  // Dispatch on the element type; only integer and float elements are folded.
+  if (llvm::isa<IntegerType>(elementType)) {
+    return foldSelections<IntegerAttr>(condAttrs, trueAttrs, falseAttrs);
+  } else if (llvm::isa<FloatType>(elementType)) {
+    return foldSelections<FloatAttr>(condAttrs, trueAttrs, falseAttrs);
+  }
+
+  return Attribute();
+}
+
+OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
+  // spirv.Select _ x x -> x
+  auto trueVals = getOperand(1);
+  auto falseVals = getOperand(2);
----------------
inbelic wrote:

Yep, much nicer for readability

https://github.com/llvm/llvm-project/pull/85430


More information about the Mlir-commits mailing list