[Mlir-commits] [mlir] 9963f16 - [mlir][spirv] Migrate to new fold API

Jakub Kuderski llvmlistbot at llvm.org
Wed Jan 11 10:56:18 PST 2023


Author: Jakub Kuderski
Date: 2023-01-11T13:55:39-05:00
New Revision: 9963f166fb5ea8b03b4455b265db3fe04fbf4cdd

URL: https://github.com/llvm/llvm-project/commit/9963f166fb5ea8b03b4455b265db3fe04fbf4cdd
DIFF: https://github.com/llvm/llvm-project/commit/9963f166fb5ea8b03b4455b265db3fe04fbf4cdd.diff

LOG: [mlir][spirv] Migrate to new fold API

Fixes: https://github.com/llvm/llvm-project/issues/59938

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D141524

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 7ca32d92c583a..0ffb38fed6692 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -48,6 +48,7 @@ def SPIRV_Dialect : Dialect {
 
   let cppNamespace = "::mlir::spirv";
   let useDefaultTypePrinterParser = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
   let hasConstantMaterializer = 1;
   let hasOperationAttrVerify = 1;
   let hasRegionArgAttrVerify = 1;

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index e7d212b5c050c..5ea8a6778023d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -116,7 +116,7 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
 // spirv.BitcastOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult spirv::BitcastOp::fold(ArrayRef<Attribute> /*operands*/) {
+OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
   Value curInput = getOperand();
   if (getType() == curInput.getType())
     return curInput;
@@ -139,7 +139,7 @@ OpFoldResult spirv::BitcastOp::fold(ArrayRef<Attribute> /*operands*/) {
 // spirv.CompositeExtractOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
   if (auto insertOp =
           getComposite().getDefiningOp<spirv::CompositeInsertOp>()) {
     if (getIndices() == insertOp.getIndices())
@@ -160,15 +160,14 @@ OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
       llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) {
         return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
       }));
-  return extractCompositeElement(operands[0], indexVector);
+  return extractCompositeElement(adaptor.getComposite(), indexVector);
 }
 
 //===----------------------------------------------------------------------===//
 // spirv.Constant
 //===----------------------------------------------------------------------===//
 
-OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.empty() && "spirv.Constant has no operands");
+OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
   return getValue();
 }
 
@@ -176,8 +175,7 @@ OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
 // spirv.IAdd
 //===----------------------------------------------------------------------===//
 
-OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "spirv.IAdd expects two operands");
+OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
   // x + 0 = x
   if (matchPattern(getOperand2(), m_Zero()))
     return getOperand1();
@@ -188,15 +186,15 @@ OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
   // R, where N is the component width and R is computed with enough precision
   // to avoid overflow and underflow.
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
+      adaptor.getOperands(),
+      [](APInt a, const APInt &b) { return std::move(a) + b; });
 }
 
 //===----------------------------------------------------------------------===//
 // spirv.IMul
 //===----------------------------------------------------------------------===//
 
-OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "spirv.IMul expects two operands");
+OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
   // x * 0 == 0
   if (matchPattern(getOperand2(), m_Zero()))
     return getOperand2();
@@ -210,14 +208,15 @@ OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
   // R, where N is the component width and R is computed with enough precision
   // to avoid overflow and underflow.
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](const APInt &a, const APInt &b) { return a * b; });
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a * b; });
 }
 
 //===----------------------------------------------------------------------===//
 // spirv.ISub
 //===----------------------------------------------------------------------===//
 
-OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
   // x - x = 0
   if (getOperand1() == getOperand2())
     return Builder(getContext()).getIntegerAttr(getType(), 0);
@@ -228,24 +227,23 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
   // R, where N is the component width and R is computed with enough precision
   // to avoid overflow and underflow.
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
+      adaptor.getOperands(),
+      [](APInt a, const APInt &b) { return std::move(a) - b; });
 }
 
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//
 
-OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "spirv.LogicalAnd should take two operands");
-
-  if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
+OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
+  if (Optional<bool> rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
     // x && true = x
     if (*rhs)
       return getOperand1();
 
     // x && false = false
     if (!*rhs)
-      return operands.back();
+      return adaptor.getOperand2();
   }
 
   return Attribute();
@@ -255,11 +253,8 @@ OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
 // spirv.LogicalNotEqualOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult spirv::LogicalNotEqualOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 &&
-         "spirv.LogicalNotEqual should take two operands");
-
-  if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
+OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
+  if (Optional<bool> rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
     // x && false = x
     if (!rhs.value())
       return getOperand1();
@@ -284,13 +279,11 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
 // spirv.LogicalOr
 //===----------------------------------------------------------------------===//
 
-OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "spirv.LogicalOr should take two operands");
-
-  if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
+OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
+  if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
     if (*rhs)
       // x || true = true
-      return operands.back();
+      return adaptor.getOperand2();
 
     // x || false = x
     if (!*rhs)


        


More information about the Mlir-commits mailing list