[Mlir-commits] [mlir] dc38cbc - [mlir][arith] Fold selection over constant vector conditions

Jakub Kuderski llvmlistbot at llvm.org
Mon Feb 13 11:12:42 PST 2023


Author: Jakub Kuderski
Date: 2023-02-13T13:58:32-05:00
New Revision: dc38cbcc8b89ea0fa1c6e2be21abe104e19d8c05

URL: https://github.com/llvm/llvm-project/commit/dc38cbcc8b89ea0fa1c6e2be21abe104e19d8c05
DIFF: https://github.com/llvm/llvm-project/commit/dc38cbcc8b89ea0fa1c6e2be21abe104e19d8c05.diff

LOG: [mlir][arith] Fold selection over constant vector conditions

Also add missing tests for the scalar and splat cases.

Reviewed By: antiagainst, Mogball

Differential Revision: https://reviews.llvm.org/D143801

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index b296422d98b9d..d3739f8dbae61 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -19,8 +19,9 @@
 #include "mlir/IR/TypeUtilities.h"
 
 #include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallString.h"
-#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
 using namespace mlir::arith;
@@ -2157,6 +2158,33 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
     }
   }
+
+  // Constant-fold constant operands over non-splat constant condition.
+  // select %cst_vec, %cst0, %cst1 => %cst2
+  if (auto cond =
+          adaptor.getCondition().dyn_cast_or_null<DenseElementsAttr>()) {
+    if (auto lhs =
+            adaptor.getTrueValue().dyn_cast_or_null<DenseElementsAttr>()) {
+      if (auto rhs =
+              adaptor.getFalseValue().dyn_cast_or_null<DenseElementsAttr>()) {
+        SmallVector<Attribute> results;
+        results.reserve(static_cast<size_t>(cond.getNumElements()));
+        auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
+                                         cond.value_end<BoolAttr>());
+        auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
+                                        lhs.value_end<Attribute>());
+        auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
+                                        rhs.value_end<Attribute>());
+
+        for (auto [condVal, lhsVal, rhsVal] :
+             llvm::zip_equal(condVals, lhsVals, rhsVals))
+          results.push_back(condVal.getValue() ? lhsVal : rhsVal);
+
+        return DenseElementsAttr::get(lhs.getType(), results);
+      }
+    }
+  }
+
   return nullptr;
 }
 

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 0048954ed161c..0ee1b0ba73333 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -54,6 +54,57 @@ func.func @select_extui_i1(%arg0: i1) -> i1 {
   return %res : i1
 }
 
+// CHECK-LABEL: @select_cst_false_scalar
+//  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+//  CHECK-NEXT:   return %[[ARG1]]
+func.func @select_cst_false_scalar(%arg0: i32, %arg1: i32) -> i32 {
+  %false = arith.constant false
+  %res = arith.select %false, %arg0, %arg1 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @select_cst_true_scalar
+//  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+//  CHECK-NEXT:   return %[[ARG0]]
+func.func @select_cst_true_scalar(%arg0: i32, %arg1: i32) -> i32 {
+  %true = arith.constant true
+  %res = arith.select %true, %arg0, %arg1 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @select_cst_true_splat
+//       CHECK:   %[[A:.+]] = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+//  CHECK-NEXT:   return %[[A]]
+func.func @select_cst_true_splat() -> vector<3xi32> {
+  %cond = arith.constant dense<true> : vector<3xi1>
+  %a = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+  %b = arith.constant dense<[4, 5, 6]> : vector<3xi32>
+  %res = arith.select %cond, %a, %b : vector<3xi1>, vector<3xi32>
+  return %res : vector<3xi32>
+}
+
+// CHECK-LABEL: @select_cst_vector_i32
+//       CHECK:   %[[RES:.+]] = arith.constant dense<[1, 5, 3]> : vector<3xi32>
+//  CHECK-NEXT:   return %[[RES]]
+func.func @select_cst_vector_i32() -> vector<3xi32> {
+  %cond = arith.constant dense<[true, false, true]> : vector<3xi1>
+  %a = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+  %b = arith.constant dense<[4, 5, 6]> : vector<3xi32>
+  %res = arith.select %cond, %a, %b : vector<3xi1>, vector<3xi32>
+  return %res : vector<3xi32>
+}
+
+// CHECK-LABEL: @select_cst_vector_f32
+//       CHECK:   %[[RES:.+]] = arith.constant dense<[4.000000e+00, 2.000000e+00, 6.000000e+00]> : vector<3xf32>
+//  CHECK-NEXT:   return %[[RES]]
+func.func @select_cst_vector_f32() -> vector<3xf32> {
+  %cond = arith.constant dense<[false, true, false]> : vector<3xi1>
+  %a = arith.constant dense<[1.0, 2.0, 3.0]> : vector<3xf32>
+  %b = arith.constant dense<[4.0, 5.0, 6.0]> : vector<3xf32>
+  %res = arith.select %cond, %a, %b : vector<3xi1>, vector<3xf32>
+  return %res : vector<3xf32>
+}
+
 // CHECK-LABEL: @selToNot
 //       CHECK:       %[[trueval:.+]] = arith.constant true
 //       CHECK:       %[[res:.+]] = arith.xori %arg0, %[[trueval]] : i1


        


More information about the Mlir-commits mailing list