[Mlir-commits] [mlir] 9963f16 - [mlir][spirv] Migrate to new fold API
Jakub Kuderski
llvmlistbot at llvm.org
Wed Jan 11 10:56:18 PST 2023
Author: Jakub Kuderski
Date: 2023-01-11T13:55:39-05:00
New Revision: 9963f166fb5ea8b03b4455b265db3fe04fbf4cdd
URL: https://github.com/llvm/llvm-project/commit/9963f166fb5ea8b03b4455b265db3fe04fbf4cdd
DIFF: https://github.com/llvm/llvm-project/commit/9963f166fb5ea8b03b4455b265db3fe04fbf4cdd.diff
LOG: [mlir][spirv] Migrate to new fold API
Fixes: https://github.com/llvm/llvm-project/issues/59938
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D141524
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 7ca32d92c583a..0ffb38fed6692 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -48,6 +48,7 @@ def SPIRV_Dialect : Dialect {
let cppNamespace = "::mlir::spirv";
let useDefaultTypePrinterParser = 1;
+ let useFoldAPI = kEmitFoldAdaptorFolder;
let hasConstantMaterializer = 1;
let hasOperationAttrVerify = 1;
let hasRegionArgAttrVerify = 1;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index e7d212b5c050c..5ea8a6778023d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -116,7 +116,7 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
// spirv.BitcastOp
//===----------------------------------------------------------------------===//
-OpFoldResult spirv::BitcastOp::fold(ArrayRef<Attribute> /*operands*/) {
+OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
Value curInput = getOperand();
if (getType() == curInput.getType())
return curInput;
@@ -139,7 +139,7 @@ OpFoldResult spirv::BitcastOp::fold(ArrayRef<Attribute> /*operands*/) {
// spirv.CompositeExtractOp
//===----------------------------------------------------------------------===//
-OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
if (auto insertOp =
getComposite().getDefiningOp<spirv::CompositeInsertOp>()) {
if (getIndices() == insertOp.getIndices())
@@ -160,15 +160,14 @@ OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) {
return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
}));
- return extractCompositeElement(operands[0], indexVector);
+ return extractCompositeElement(adaptor.getComposite(), indexVector);
}
//===----------------------------------------------------------------------===//
// spirv.Constant
//===----------------------------------------------------------------------===//
-OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.empty() && "spirv.Constant has no operands");
+OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
return getValue();
}
@@ -176,8 +175,7 @@ OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
// spirv.IAdd
//===----------------------------------------------------------------------===//
-OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "spirv.IAdd expects two operands");
+OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
// x + 0 = x
if (matchPattern(getOperand2(), m_Zero()))
return getOperand1();
@@ -188,15 +186,15 @@ OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
// R, where N is the component width and R is computed with enough precision
// to avoid overflow and underflow.
return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
+ adaptor.getOperands(),
+ [](APInt a, const APInt &b) { return std::move(a) + b; });
}
//===----------------------------------------------------------------------===//
// spirv.IMul
//===----------------------------------------------------------------------===//
-OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "spirv.IMul expects two operands");
+OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
// x * 0 == 0
if (matchPattern(getOperand2(), m_Zero()))
return getOperand2();
@@ -210,14 +208,15 @@ OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
// R, where N is the component width and R is computed with enough precision
// to avoid overflow and underflow.
return constFoldBinaryOp<IntegerAttr>(
- operands, [](const APInt &a, const APInt &b) { return a * b; });
+ adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) { return a * b; });
}
//===----------------------------------------------------------------------===//
// spirv.ISub
//===----------------------------------------------------------------------===//
-OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
// x - x = 0
if (getOperand1() == getOperand2())
return Builder(getContext()).getIntegerAttr(getType(), 0);
@@ -228,24 +227,23 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
// R, where N is the component width and R is computed with enough precision
// to avoid overflow and underflow.
return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
+ adaptor.getOperands(),
+ [](APInt a, const APInt &b) { return std::move(a) - b; });
}
//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//
-OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "spirv.LogicalAnd should take two operands");
-
- if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
+OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
+ if (Optional<bool> rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
// x && true = x
if (*rhs)
return getOperand1();
// x && false = false
if (!*rhs)
- return operands.back();
+ return adaptor.getOperand2();
}
return Attribute();
@@ -255,11 +253,8 @@ OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
// spirv.LogicalNotEqualOp
//===----------------------------------------------------------------------===//
-OpFoldResult spirv::LogicalNotEqualOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 &&
- "spirv.LogicalNotEqual should take two operands");
-
- if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
+OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
+ if (Optional<bool> rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
// x && false = x
if (!rhs.value())
return getOperand1();
@@ -284,13 +279,11 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
// spirv.LogicalOr
//===----------------------------------------------------------------------===//
-OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "spirv.LogicalOr should take two operands");
-
- if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
+OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
+ if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
if (*rhs)
// x || true = true
- return operands.back();
+ return adaptor.getOperand2();
// x || false = x
if (!*rhs)
More information about the Mlir-commits
mailing list