[Mlir-commits] [mlir] 4a10457 - [mlir][arith] Fix CmpIOp folding for vector types.

Adrian Kuegel llvmlistbot at llvm.org
Wed Dec 22 09:12:41 PST 2021


Author: Adrian Kuegel
Date: 2021-12-22T18:12:24+01:00
New Revision: 4a10457d33e92f2e024f6d024d168ddcd49c3a59

URL: https://github.com/llvm/llvm-project/commit/4a10457d33e92f2e024f6d024d168ddcd49c3a59
DIFF: https://github.com/llvm/llvm-project/commit/4a10457d33e92f2e024f6d024d168ddcd49c3a59.diff

LOG: [mlir][arith] Fix CmpIOp folding for vector types.

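Previously, folding `cmpi` with identical operands (`cmpi(pred, x, x)`) always produced a scalar `BoolAttr`, even when the operands, and therefore the result, had a shaped type such as a vector. The fold now produces a splat `DenseElementsAttr` of the result type in that case.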

Differential Revision: https://reviews.llvm.org/D116151

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index aca6e4f3c27af..a413fb263775d 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1056,13 +1056,21 @@ static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
   llvm_unreachable("unknown cmpi predicate kind");
 }
 
+static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
+  auto boolAttr = BoolAttr::get(ctx, value);
+  ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
+  if (!shapedType)
+    return boolAttr;
+  return DenseElementsAttr::get(shapedType, boolAttr);
+}
+
 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpi takes two operands");
 
   // cmpi(pred, x, x)
   if (getLhs() == getRhs()) {
     auto val = applyCmpPredicateToEqualOperands(getPredicate());
-    return BoolAttr::get(getContext(), val);
+    return getBoolAttribute(getType(), getContext(), val);
   }
 
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 328ed1d028c74..96a630a248cfc 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -22,6 +22,32 @@ func @cmpi_equal_operands(%arg0: i64)
       : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
 }
 
+// Test case: Folding of comparisons with equal vector operands.
+// CHECK-LABEL: @cmpi_equal_vector_operands
+//   CHECK-DAG:   %[[T:.*]] = arith.constant dense<true>
+//   CHECK-DAG:   %[[F:.*]] = arith.constant dense<false>
+//       CHECK:   return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
+//  CHECK-SAME:          %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
+func @cmpi_equal_vector_operands(%arg0: vector<1x8xi64>)
+    -> (vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
+        vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
+	vector<1x8xi1>, vector<1x8xi1>) {
+  %0 = arith.cmpi eq, %arg0, %arg0 : vector<1x8xi64>
+  %1 = arith.cmpi sle, %arg0, %arg0 : vector<1x8xi64>
+  %2 = arith.cmpi sge, %arg0, %arg0 : vector<1x8xi64>
+  %3 = arith.cmpi ule, %arg0, %arg0 : vector<1x8xi64>
+  %4 = arith.cmpi uge, %arg0, %arg0 : vector<1x8xi64>
+  %5 = arith.cmpi ne, %arg0, %arg0 : vector<1x8xi64>
+  %6 = arith.cmpi slt, %arg0, %arg0 : vector<1x8xi64>
+  %7 = arith.cmpi sgt, %arg0, %arg0 : vector<1x8xi64>
+  %8 = arith.cmpi ult, %arg0, %arg0 : vector<1x8xi64>
+  %9 = arith.cmpi ugt, %arg0, %arg0 : vector<1x8xi64>
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+      : vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
+        vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
+	vector<1x8xi1>, vector<1x8xi1>
+}
+
 // -----
 
 // CHECK-LABEL: @indexCastOfSignExtend


        


More information about the Mlir-commits mailing list