[Mlir-commits] [mlir] 35ca649 - [mlir][arith] Don't crash when folding a & ~a -> 0 on vectors

Benjamin Kramer llvmlistbot at llvm.org
Wed Jan 25 05:09:28 PST 2023


Author: Benjamin Kramer
Date: 2023-01-25T14:08:28+01:00
New Revision: 35ca64989a75c93ea7e935ef11c3d1883c21cccd

URL: https://github.com/llvm/llvm-project/commit/35ca64989a75c93ea7e935ef11c3d1883c21cccd
DIFF: https://github.com/llvm/llvm-project/commit/35ca64989a75c93ea7e935ef11c3d1883c21cccd.diff

LOG: [mlir][arith] Don't crash when folding a & ~a -> 0 on vectors

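m_ConstantInt happily accepts vector splats, so the result type of the andi
may be a vector, and IntegerAttr::get(getType(), 0) cannot build an attribute
for a vector type. Just use the generic way of getting a zero attribute
(Builder::getZeroAttr), which handles both scalar and vector types.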

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 63febd8577369..b296422d98b9d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -778,12 +778,12 @@ OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
                                           m_ConstantInt(&intValue))) &&
       intValue.isAllOnes())
-    return IntegerAttr::get(getType(), 0);
+    return Builder(getContext()).getZeroAttr(getType());
   /// and(not(x), x) -> 0
   if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
                                           m_ConstantInt(&intValue))) &&
       intValue.isAllOnes())
-    return IntegerAttr::get(getType(), 0);
+    return Builder(getContext()).getZeroAttr(getType());
 
   /// and(a, and(a, b)) -> and(a, b)
   if (Value result = foldAndIofAndI(*this))

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 5806c9c2b365a..0048954ed161c 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1871,6 +1871,33 @@ func.func @test_andi_not_fold_lhs(%arg0 : index) -> index {
     return %2 : index
 }
 
+// -----
+
+// CHECK-LABEL: @test_andi_not_fold_rhs_vec(
+// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
+// CHECK: %[[C:.*]] = arith.constant dense<0> : vector<2xi32>
+// CHECK: return %[[C]]
+
+func.func @test_andi_not_fold_rhs_vec(%arg0 : vector<2xi32>) -> vector<2xi32> {
+    %0 = arith.constant dense<[-1, -1]> : vector<2xi32>
+    %1 = arith.xori %arg0, %0 : vector<2xi32>
+    %2 = arith.andi %arg0, %1 : vector<2xi32>
+    return %2 : vector<2xi32>
+}
+
+
+// CHECK-LABEL: @test_andi_not_fold_lhs_vec(
+// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
+// CHECK: %[[C:.*]] = arith.constant dense<0> : vector<2xi32>
+// CHECK: return %[[C]]
+
+func.func @test_andi_not_fold_lhs_vec(%arg0 : vector<2xi32>) -> vector<2xi32> {
+    %0 = arith.constant dense<[-1, -1]> : vector<2xi32>
+    %1 = arith.xori %arg0, %0 : vector<2xi32>
+    %2 = arith.andi %1, %arg0 : vector<2xi32>
+    return %2 : vector<2xi32>
+}
+
 // -----
 /// xor(xor(x, a), a) -> x
 


        


More information about the Mlir-commits mailing list