[Mlir-commits] [mlir] dc38cbc - [mlir][arith] Fold selection over constant vector conditions
Jakub Kuderski
llvmlistbot at llvm.org
Mon Feb 13 11:12:42 PST 2023
Author: Jakub Kuderski
Date: 2023-02-13T13:58:32-05:00
New Revision: dc38cbcc8b89ea0fa1c6e2be21abe104e19d8c05
URL: https://github.com/llvm/llvm-project/commit/dc38cbcc8b89ea0fa1c6e2be21abe104e19d8c05
DIFF: https://github.com/llvm/llvm-project/commit/dc38cbcc8b89ea0fa1c6e2be21abe104e19d8c05.diff
LOG: [mlir][arith] Fold selection over constant vector conditions
Also add missing tests for the scalar and splat cases.
Reviewed By: antiagainst, Mogball
Differential Revision: https://reviews.llvm.org/D143801
Added:
Modified:
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index b296422d98b9d..d3739f8dbae61 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -19,8 +19,9 @@
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
-#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/SmallVector.h"
using namespace mlir;
using namespace mlir::arith;
@@ -2157,6 +2158,33 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
}
}
+
+ // Constant-fold constant operands over non-splat constant condition.
+ // select %cst_vec, %cst0, %cst1 => %cst2
+ if (auto cond =
+ adaptor.getCondition().dyn_cast_or_null<DenseElementsAttr>()) {
+ if (auto lhs =
+ adaptor.getTrueValue().dyn_cast_or_null<DenseElementsAttr>()) {
+ if (auto rhs =
+ adaptor.getFalseValue().dyn_cast_or_null<DenseElementsAttr>()) {
+ SmallVector<Attribute> results;
+ results.reserve(static_cast<size_t>(cond.getNumElements()));
+ auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
+ cond.value_end<BoolAttr>());
+ auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
+ lhs.value_end<Attribute>());
+ auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
+ rhs.value_end<Attribute>());
+
+ for (auto [condVal, lhsVal, rhsVal] :
+ llvm::zip_equal(condVals, lhsVals, rhsVals))
+ results.push_back(condVal.getValue() ? lhsVal : rhsVal);
+
+ return DenseElementsAttr::get(lhs.getType(), results);
+ }
+ }
+ }
+
return nullptr;
}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 0048954ed161c..0ee1b0ba73333 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -54,6 +54,57 @@ func.func @select_extui_i1(%arg0: i1) -> i1 {
return %res : i1
}
+// CHECK-LABEL: @select_cst_false_scalar
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+// CHECK-NEXT: return %[[ARG1]]
+func.func @select_cst_false_scalar(%arg0: i32, %arg1: i32) -> i32 {
+ %false = arith.constant false
+ %res = arith.select %false, %arg0, %arg1 : i32
+ return %res : i32
+}
+
+// CHECK-LABEL: @select_cst_true_scalar
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+// CHECK-NEXT: return %[[ARG0]]
+func.func @select_cst_true_scalar(%arg0: i32, %arg1: i32) -> i32 {
+ %true = arith.constant true
+ %res = arith.select %true, %arg0, %arg1 : i32
+ return %res : i32
+}
+
+// CHECK-LABEL: @select_cst_true_splat
+// CHECK: %[[A:.+]] = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+// CHECK-NEXT: return %[[A]]
+func.func @select_cst_true_splat() -> vector<3xi32> {
+ %cond = arith.constant dense<true> : vector<3xi1>
+ %a = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+ %b = arith.constant dense<[4, 5, 6]> : vector<3xi32>
+ %res = arith.select %cond, %a, %b : vector<3xi1>, vector<3xi32>
+ return %res : vector<3xi32>
+}
+
+// CHECK-LABEL: @select_cst_vector_i32
+// CHECK: %[[RES:.+]] = arith.constant dense<[1, 5, 3]> : vector<3xi32>
+// CHECK-NEXT: return %[[RES]]
+func.func @select_cst_vector_i32() -> vector<3xi32> {
+ %cond = arith.constant dense<[true, false, true]> : vector<3xi1>
+ %a = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+ %b = arith.constant dense<[4, 5, 6]> : vector<3xi32>
+ %res = arith.select %cond, %a, %b : vector<3xi1>, vector<3xi32>
+ return %res : vector<3xi32>
+}
+
+// CHECK-LABEL: @select_cst_vector_f32
+// CHECK: %[[RES:.+]] = arith.constant dense<[4.000000e+00, 2.000000e+00, 6.000000e+00]> : vector<3xf32>
+// CHECK-NEXT: return %[[RES]]
+func.func @select_cst_vector_f32() -> vector<3xf32> {
+ %cond = arith.constant dense<[false, true, false]> : vector<3xi1>
+ %a = arith.constant dense<[1.0, 2.0, 3.0]> : vector<3xf32>
+ %b = arith.constant dense<[4.0, 5.0, 6.0]> : vector<3xf32>
+ %res = arith.select %cond, %a, %b : vector<3xi1>, vector<3xf32>
+ return %res : vector<3xf32>
+}
+
// CHECK-LABEL: @selToNot
// CHECK: %[[trueval:.+]] = arith.constant true
// CHECK: %[[res:.+]] = arith.xori %arg0, %[[trueval]] : i1
More information about the Mlir-commits mailing list