[Mlir-commits] [mlir] 38f34e5 - [mlir][Arith] Fix folder of CmpIOp to not fail when element type is not integer.
Mahesh Ravishankar
llvmlistbot at llvm.org
Thu Nov 3 13:39:41 PDT 2022
Author: Mahesh Ravishankar
Date: 2022-11-03T20:38:34Z
New Revision: 38f34e587d10fcd7d18fd240e41248006faa639e
URL: https://github.com/llvm/llvm-project/commit/38f34e587d10fcd7d18fd240e41248006faa639e
DIFF: https://github.com/llvm/llvm-project/commit/38f34e587d10fcd7d18fd240e41248006faa639e.diff
LOG: [mlir][Arith] Fix folder of CmpIOp to not fail when element type is not integer.
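The folder used `cast<IntegerType>` on the operand type of the `extsi`/`extui`, which asserts (or crashes in release builds) when that type is a vector.
Use a helper that returns the bit width for scalar integer and integer-vector types, and skip the fold when no width is available.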
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D137345
Added:
Modified:
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d1d03a549092d..2c0fc51d08a40 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -19,6 +19,7 @@
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::arith;
@@ -1444,6 +1445,16 @@ static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
return DenseElementsAttr::get(shapedType, boolAttr);
}
+// Returns the bit width of `t` when it is an integer type or a vector of
+// integers, and llvm::None for any other type (including non-integer vectors).
+static Optional<int64_t> getIntegerWidth(Type t) {
+  if (auto vectorType = t.dyn_cast<VectorType>())
+    t = vectorType.getElementType();
+  if (auto intType = t.dyn_cast<IntegerType>())
+    return intType.getWidth();
+  return llvm::None;
+}
+
OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpi takes two operands");
@@ -1456,13 +1467,17 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
if (matchPattern(getRhs(), m_Zero())) {
if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
// extsi(%x : i1 -> iN) != 0 -> %x
- if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
+ Optional<int64_t> integerWidth =
+ getIntegerWidth(extOp.getOperand().getType());
+ if (integerWidth && integerWidth.value() == 1 &&
getPredicate() == arith::CmpIPredicate::ne)
return extOp.getOperand();
}
if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
// extui(%x : i1 -> iN) != 0 -> %x
- if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
+ Optional<int64_t> integerWidth =
+ getIntegerWidth(extOp.getOperand().getType());
+ if (integerWidth && integerWidth.value() == 1 &&
getPredicate() == arith::CmpIPredicate::ne)
return extOp.getOperand();
}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 337eec00f3bf9..336324ef4eec9 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -162,7 +162,7 @@ func.func @cmpi_const_right(%arg0: i64)
// -----
-// CHECK-LABEL: @cmpOfExtSI
+// CHECK-LABEL: @cmpOfExtSI(
// CHECK-NEXT: return %arg0
func.func @cmpOfExtSI(%arg0: i1) -> i1 {
%ext = arith.extsi %arg0 : i1 to i64
@@ -171,7 +171,7 @@ func.func @cmpOfExtSI(%arg0: i1) -> i1 {
return %res : i1
}
-// CHECK-LABEL: @cmpOfExtUI
+// CHECK-LABEL: @cmpOfExtUI(
// CHECK-NEXT: return %arg0
func.func @cmpOfExtUI(%arg0: i1) -> i1 {
%ext = arith.extui %arg0 : i1 to i64
@@ -182,6 +182,26 @@ func.func @cmpOfExtUI(%arg0: i1) -> i1 {
// -----
+// CHECK-LABEL: @cmpOfExtSIVector(
+// CHECK-NEXT: return %arg0
+func.func @cmpOfExtSIVector(%arg0: vector<4xi1>) -> vector<4xi1> {
+  %zero = arith.constant dense<0> : vector<4xi64>
+  %sext = arith.extsi %arg0 : vector<4xi1> to vector<4xi64>
+  %ne = arith.cmpi ne, %sext, %zero : vector<4xi64>
+  return %ne : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpOfExtUIVector(
+// CHECK-NEXT: return %arg0
+func.func @cmpOfExtUIVector(%arg0: vector<4xi1>) -> vector<4xi1> {
+  %zero = arith.constant dense<0> : vector<4xi64>
+  %zext = arith.extui %arg0 : vector<4xi1> to vector<4xi64>
+  %ne = arith.cmpi ne, %zext, %zero : vector<4xi64>
+  return %ne : vector<4xi1>
+}
+
+// -----
+
// CHECK-LABEL: @extSIOfExtUI
// CHECK: %[[res:.+]] = arith.extui %arg0 : i1 to i64
// CHECK: return %[[res]]
@@ -1660,3 +1680,5 @@ func.func @xorxor3(%a : i32, %b : i32) -> i32 {
%res = arith.xori %b, %c : i32
return %res : i32
}
+
+// -----
More information about the Mlir-commits
mailing list