[Mlir-commits] [mlir] [mlir][spirv] Add folding for SelectOp (PR #85430)

Finn Plummer llvmlistbot at llvm.org
Tue Mar 19 12:18:27 PDT 2024


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/85430

>From 4998cc1f0c3e0aaa23dde9baf86b04be90978ba9 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 12 Mar 2024 11:07:42 -0700
Subject: [PATCH 1/3] [mlir][spirv] Add folding for SelectOp

Add the missing constant-propagation folder for spirv.Select.

Also fold when both selected operands are the same value, or when the
condition is a constant scalar or splat vector.

Allows for constant folding in the IndexToSPIRV pass.

Part of the work for #70704.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td  |  2 +
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 79 +++++++++++++++++++
 .../SPIRVToLLVM/misc-ops-to-llvm.mlir         | 10 +--
 .../SPIRV/Transforms/canonicalize.mlir        | 46 +++++++++++
 4 files changed, 132 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index e48a56f0625d3f..3ee239d6e1e3ec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -800,6 +800,8 @@ def SPIRV_SelectOp : SPIRV_Op<"Select",
   // These ops require dynamic availability specification based on operand and
   // result types.
   bit autogenAvailability = 0;
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 4c62289a1e9458..ceb929e7e09808 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -797,6 +797,85 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SelectOp
+//===----------------------------------------------------------------------===//
+
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType>
+static Attribute foldSelections(const ElementsAttr &condAttrs,
+                                const ElementsAttr &trueAttrs,
+                                const ElementsAttr &falseAttrs) {
+  auto condsIt = condAttrs.value_begin<BoolAttr>();
+  auto trueAttrsIt = trueAttrs.value_begin<ElementValueT>();
+  auto falseAttrsIt = falseAttrs.value_begin<ElementValueT>();
+
+  SmallVector<ElementValueT, 4> elementResults;
+  elementResults.reserve(condAttrs.getNumElements());
+  for (size_t i = 0, e = condAttrs.getNumElements(); i < e;
+       ++i, ++condsIt, ++trueAttrsIt, ++falseAttrsIt) {
+    if ((*condsIt).getValue()) // If Condition then take Object 1
+      elementResults.push_back(*trueAttrsIt);
+    else // Else take Object 2
+      elementResults.push_back(*falseAttrsIt);
+  }
+
+  auto resultType = trueAttrs.getType();
+  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
+}
+
+static Attribute foldSelectOp(llvm::ArrayRef<Attribute> operands) {
+  auto condAttrs = llvm::dyn_cast<ElementsAttr>(operands[0]);
+  auto trueAttrs = llvm::dyn_cast<ElementsAttr>(operands[1]);
+  auto falseAttrs = llvm::dyn_cast<ElementsAttr>(operands[2]);
+  if (!condAttrs || !trueAttrs || !falseAttrs)
+    return Attribute();
+
+  // According to the SPIR-V spec:
+  //
+  // If Condition is a vector, Result Type must be a vector with the same
+  // number of components as Condition and the result is a mix of Object 1
+  // and Object 2: When a component of Condition is true, the corresponding
+  // component in the result is taken from Object 1, otherwise it is taken
+  // from Object 2.
+  auto elementType = trueAttrs.getElementType();
+  if (trueAttrs.getType() != falseAttrs.getType() ||
+      !condAttrs.getElementType().isInteger(1))
+    return Attribute();
+
+  if (llvm::isa<IntegerType>(elementType)) {
+    return foldSelections<IntegerAttr>(condAttrs, trueAttrs, falseAttrs);
+  } else if (llvm::isa<FloatType>(elementType)) {
+    return foldSelections<FloatAttr>(condAttrs, trueAttrs, falseAttrs);
+  }
+
+  return Attribute();
+}
+
+OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
+  // spirv.Select _ x x -> x
+  auto trueVals = getOperand(1);
+  auto falseVals = getOperand(2);
+  if (trueVals == falseVals)
+    return trueVals;
+
+  auto operands = adaptor.getOperands();
+
+  // spirv.Select true  x y -> x
+  // spirv.Select false x y -> y
+  if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
+    return *boolAttr ? trueVals : falseVals;
+
+  // Check that all the operands are constant
+  if (!operands[0] || !operands[1] || !operands[2])
+    return Attribute();
+
+  // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
+  // the scalar case. Hence, we are only required to consider the case of
+  // ElementsAttr in foldSelectOp.
+  return foldSelectOp(operands);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.IEqualOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
index 9fe1e532dfc77c..31da59dcdc7260 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
@@ -43,18 +43,18 @@ spirv.func @composite_insert_vector(%arg0: vector<3xf32>, %arg1: f32) "None" {
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @select_scalar
-spirv.func @select_scalar(%arg0: i1, %arg1: vector<3xi32>, %arg2: f32) "None" {
+spirv.func @select_scalar(%arg0: i1, %arg1: vector<3xi32>, %arg2: vector<3xi32>, %arg3: f32, %arg4: f32) "None" {
   // CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, vector<3xi32>
-  %0 = spirv.Select %arg0, %arg1, %arg1 : i1, vector<3xi32>
+  %0 = spirv.Select %arg0, %arg1, %arg2 : i1, vector<3xi32>
   // CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, f32
-  %1 = spirv.Select %arg0, %arg2, %arg2 : i1, f32
+  %1 = spirv.Select %arg0, %arg3, %arg4 : i1, f32
   spirv.Return
 }
 
 // CHECK-LABEL: @select_vector
-spirv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>) "None" {
+spirv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>, %arg2: vector<2xi32>) "None" {
   // CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi1>, vector<2xi32>
-  %0 = spirv.Select %arg0, %arg1, %arg1 : vector<2xi1>, vector<2xi32>
+  %0 = spirv.Select %arg0, %arg1, %arg2 : vector<2xi1>, vector<2xi32>
   spirv.Return
 }
 
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 1cb69891a70ed6..de21d114e9fc4f 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1346,6 +1346,52 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.Select
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @convert_select_scalar
+// CHECK-SAME: %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32
+func.func @convert_select_scalar(%arg1: i32, %arg2: i32) -> (i32, i32) {
+  %true = spirv.Constant true
+  %false = spirv.Constant false
+  %0 = spirv.Select %true, %arg1, %arg2 : i1, i32
+  %1 = spirv.Select %false, %arg1, %arg2 : i1, i32
+
+  // CHECK: return %[[ARG1]], %[[ARG2]]
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @convert_select_vector
+// CHECK-SAME: %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>
+func.func @convert_select_vector(%arg1: vector<3xi32>, %arg2: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
+  %true = spirv.Constant dense<true> : vector<3xi1>
+  %false = spirv.Constant dense<false> : vector<3xi1>
+  %0 = spirv.Select %true, %arg1, %arg2 : vector<3xi1>, vector<3xi32>
+  %1 = spirv.Select %false, %arg1, %arg2 : vector<3xi1>, vector<3xi32>
+
+  // CHECK: return %[[ARG1]], %[[ARG2]]
+  return %0, %1: vector<3xi32>, vector<3xi32>
+}
+
+// CHECK-LABEL: @convert_select_vector_extra
+// CHECK-SAME: %[[CONDITIONS:.+]]: vector<2xi1>, %[[ARG1:.+]]: vector<2xi32>
+func.func @convert_select_vector_extra(%conditions: vector<2xi1>, %arg1: vector<2xi32>) -> (vector<2xi32>, vector<2xi32>) {
+  %true_false = spirv.Constant dense<[true, false]> : vector<2xi1>
+  %cvec_1 = spirv.Constant dense<[42, -132]> : vector<2xi32>
+  %cvec_2 = spirv.Constant dense<[0, 42]> : vector<2xi32>
+
+  // CHECK: %[[RES:.+]] = spirv.Constant dense<42>
+  %0 = spirv.Select %true_false, %cvec_1, %cvec_2: vector<2xi1>, vector<2xi32>
+
+  %1 = spirv.Select %conditions, %arg1, %arg1 : vector<2xi1>, vector<2xi32>
+
+  // CHECK: return %[[RES]], %[[ARG1]]
+  return %0, %1: vector<2xi32>, vector<2xi32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IEqual
 //===----------------------------------------------------------------------===//

>From 7785e59a3d1d1475e0545b37674ff5f652049270 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 19 Mar 2024 11:02:36 -0700
Subject: [PATCH 2/3] review comments

- remove use of auto when type is unclear
- simplify the folding, since the integer and float cases need not be handled separately
- remove unneeded llvm namespacing
- use operand specific getters for readability
---
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 54 +++++--------------
 1 file changed, 14 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index ceb929e7e09808..ba99352fadcbfd 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -801,16 +801,18 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
 // spirv.SelectOp
 //===----------------------------------------------------------------------===//
 
-template <class AttrElementT,
-          class ElementValueT = typename AttrElementT::ValueType>
-static Attribute foldSelections(const ElementsAttr &condAttrs,
-                                const ElementsAttr &trueAttrs,
-                                const ElementsAttr &falseAttrs) {
+static Attribute foldSelectOp(ArrayRef<Attribute> operands) {
+  auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
+  auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
+  auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
+  if (!condAttrs || !trueAttrs || !falseAttrs)
+    return Attribute();
+
   auto condsIt = condAttrs.value_begin<BoolAttr>();
-  auto trueAttrsIt = trueAttrs.value_begin<ElementValueT>();
-  auto falseAttrsIt = falseAttrs.value_begin<ElementValueT>();
+  auto trueAttrsIt = trueAttrs.value_begin<Attribute>();
+  auto falseAttrsIt = falseAttrs.value_begin<Attribute>();
 
-  SmallVector<ElementValueT, 4> elementResults;
+  SmallVector<Attribute, 4> elementResults;
   elementResults.reserve(condAttrs.getNumElements());
   for (size_t i = 0, e = condAttrs.getNumElements(); i < e;
        ++i, ++condsIt, ++trueAttrsIt, ++falseAttrsIt) {
@@ -824,42 +826,14 @@ static Attribute foldSelections(const ElementsAttr &condAttrs,
   return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
 }
 
-static Attribute foldSelectOp(llvm::ArrayRef<Attribute> operands) {
-  auto condAttrs = llvm::dyn_cast<ElementsAttr>(operands[0]);
-  auto trueAttrs = llvm::dyn_cast<ElementsAttr>(operands[1]);
-  auto falseAttrs = llvm::dyn_cast<ElementsAttr>(operands[2]);
-  if (!condAttrs || !trueAttrs || !falseAttrs)
-    return Attribute();
-
-  // According to the SPIR-V spec:
-  //
-  // If Condition is a vector, Result Type must be a vector with the same
-  // number of components as Condition and the result is a mix of Object 1
-  // and Object 2: When a component of Condition is true, the corresponding
-  // component in the result is taken from Object 1, otherwise it is taken
-  // from Object 2.
-  auto elementType = trueAttrs.getElementType();
-  if (trueAttrs.getType() != falseAttrs.getType() ||
-      !condAttrs.getElementType().isInteger(1))
-    return Attribute();
-
-  if (llvm::isa<IntegerType>(elementType)) {
-    return foldSelections<IntegerAttr>(condAttrs, trueAttrs, falseAttrs);
-  } else if (llvm::isa<FloatType>(elementType)) {
-    return foldSelections<FloatAttr>(condAttrs, trueAttrs, falseAttrs);
-  }
-
-  return Attribute();
-}
-
 OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
   // spirv.Select _ x x -> x
-  auto trueVals = getOperand(1);
-  auto falseVals = getOperand(2);
+  Value trueVals = getTrueValue();
+  Value falseVals = getFalseValue();
   if (trueVals == falseVals)
     return trueVals;
 
-  auto operands = adaptor.getOperands();
+  ArrayRef<Attribute> operands = adaptor.getOperands();
 
   // spirv.Select true  x y -> x
   // spirv.Select false x y -> y
@@ -872,7 +846,7 @@ OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
 
   // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
   // the scalar case. Hence, we are only required to consider the case of
-  // ElementsAttr in foldSelectOp.
+  // DenseElementsAttr in foldSelectOp.
   return foldSelectOp(operands);
 }
 

>From 9539ff518e4b60d0d67488aa88e43f0bdb10fae1 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 19 Mar 2024 12:15:32 -0700
Subject: [PATCH 3/3] review comments

- use llvm::to_vector and llvm::zip_equal to improve loop readability
- inline foldSelectOp into SelectOp::fold, given its simplicity
---
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 42 +++++++------------
 1 file changed, 16 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index ba99352fadcbfd..ff4bace9a4d882 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -801,31 +801,6 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
 // spirv.SelectOp
 //===----------------------------------------------------------------------===//
 
-static Attribute foldSelectOp(ArrayRef<Attribute> operands) {
-  auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
-  auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
-  auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
-  if (!condAttrs || !trueAttrs || !falseAttrs)
-    return Attribute();
-
-  auto condsIt = condAttrs.value_begin<BoolAttr>();
-  auto trueAttrsIt = trueAttrs.value_begin<Attribute>();
-  auto falseAttrsIt = falseAttrs.value_begin<Attribute>();
-
-  SmallVector<Attribute, 4> elementResults;
-  elementResults.reserve(condAttrs.getNumElements());
-  for (size_t i = 0, e = condAttrs.getNumElements(); i < e;
-       ++i, ++condsIt, ++trueAttrsIt, ++falseAttrsIt) {
-    if ((*condsIt).getValue()) // If Condition then take Object 1
-      elementResults.push_back(*trueAttrsIt);
-    else // Else take Object 2
-      elementResults.push_back(*falseAttrsIt);
-  }
-
-  auto resultType = trueAttrs.getType();
-  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
-}
-
 OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
   // spirv.Select _ x x -> x
   Value trueVals = getTrueValue();
@@ -847,7 +822,22 @@ OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
   // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
   // the scalar case. Hence, we are only required to consider the case of
   // DenseElementsAttr in foldSelectOp.
-  return foldSelectOp(operands);
+  auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
+  auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
+  auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
+  if (!condAttrs || !trueAttrs || !falseAttrs)
+    return Attribute();
+
+  auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
+  auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
+                               falseAttrs.getValues<Attribute>());
+  for (auto [result, cond, falseRes] : iters) {
+    if (!cond.getValue())
+      result = falseRes;
+  }
+
+  auto resultType = trueAttrs.getType();
+  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list