[Mlir-commits] [mlir] 0813700 - [mlir][NFC] Cleanup: Move helper functions to StaticValueUtils

Matthias Springer llvmlistbot at llvm.org
Sat Jun 26 23:57:47 PDT 2021


Author: Matthias Springer
Date: 2021-06-27T15:56:48+09:00
New Revision: 0813700de1af72173ad18202fcbd3eafce90d184

URL: https://github.com/llvm/llvm-project/commit/0813700de1af72173ad18202fcbd3eafce90d184
DIFF: https://github.com/llvm/llvm-project/commit/0813700de1af72173ad18202fcbd3eafce90d184.diff

LOG: [mlir][NFC] Cleanup: Move helper functions to StaticValueUtils

Reduce code duplication: move various helper functions that are duplicated in TensorDialect, MemRefDialect, LinalgDialect, and StandardDialect into a new StaticValueUtils.cpp.

Differential Revision: https://reviews.llvm.org/D104687

Added: 
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/include/mlir/Interfaces/ViewLikeInterface.td
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Utils/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index df568d6795d43..ffd65f7138efc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -269,13 +269,13 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     // Return true if low padding is guaranteed to be 0.
     bool hasZeroLowPad() {
       return llvm::all_of(getMixedLowPad(), [](OpFoldResult ofr) {
-        return mlir::isEqualConstantInt(ofr, 0);
+        return getConstantIntValue(ofr) == static_cast<int64_t>(0);
       });
     }
     // Return true if high padding is guaranteed to be 0.
     bool hasZeroHighPad() {
       return llvm::all_of(getMixedHighPad(), [](OpFoldResult ofr) {
-        return mlir::isEqualConstantInt(ofr, 0);
+        return getConstantIntValue(ofr) == static_cast<int64_t>(0);
       });
     }
   }];

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f6b78ae385d04..5f533df137419 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index bff62c716dfe6..477474b41da46 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -114,21 +114,6 @@ bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
 bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                        const APFloat &rhs);
 
-/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
-/// IntegerAttr, return the integer.
-llvm::Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
-
-/// Return true if ofr and value are the same integer.
-/// Ignore integer bitwidth and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType has no bitwidth.
-bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
-
-/// Return true if ofr1 and ofr2 are the same integer constant attribute values
-/// or the same SSA value.
-/// Ignore integer bitwitdh and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType have no bitwidth.
-bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
-
 /// Returns the identity value attribute associated with an AtomicRMWKind op.
 Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                OpBuilder &builder, Location loc);

diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
new file mode 100644
index 0000000000000..3284c022a7255
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -0,0 +1,58 @@
+//===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities for dealing with static values, e.g.,
+// converting back and forth between Value and OpFoldResult. Such functionality
+// is used in multiple dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
+#define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
+/// it is a Value or into `staticVec` if it is an IntegerAttr.
+/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// `staticVec`. This is useful to extract mixed static and dynamic entries that
+/// come from an AttrSizedOperandSegments trait.
+void dispatchIndexOpFoldResult(OpFoldResult ofr,
+                               SmallVectorImpl<Value> &dynamicVec,
+                               SmallVectorImpl<int64_t> &staticVec,
+                               int64_t sentinel);
+
+/// Helper function to dispatch multiple OpFoldResults into either the
+/// `dynamicVec` (for Values) or into `staticVec` (for IntegerAttrs).
+/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// `staticVec`. This is useful to extract mixed static and dynamic entries that
+/// come from an AttrSizedOperandSegments trait.
+void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
+                                SmallVectorImpl<Value> &dynamicVec,
+                                SmallVectorImpl<int64_t> &staticVec,
+                                int64_t sentinel);
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
+
+/// If ofr is a constant integer or an IntegerAttr, return the integer.
+Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
+
+/// Return true if ofr1 and ofr2 are the same integer constant attribute values
+/// or the same SSA value.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 8d58570148910..7df3d1e95bab4 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
 #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
 
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -30,8 +31,6 @@ struct Range {
 
 class OffsetSizeAndStrideOpInterface;
 
-bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
-
 namespace detail {
 LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
 

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index 62f24f2b97362..2ba9038cec775 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -444,7 +444,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return ::llvm::all_of(getMixedStrides(), [](OpFoldResult ofr) {
-          return ::mlir::isEqualConstantInt(ofr, 1);
+          return ::mlir::getConstantIntValue(ofr) == static_cast<int64_t>(1);
         });
       }]
     >,
@@ -456,7 +456,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return ::llvm::all_of(getMixedOffsets(), [](OpFoldResult ofr) {
-          return ::mlir::isEqualConstantInt(ofr, 0);
+          return ::mlir::getConstantIntValue(ofr) == static_cast<int64_t>(0);
         });
       }]
     >,

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 8e808d75e205e..db5918e95f182 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
@@ -3388,14 +3389,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   }
 };
 
-/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
-        return a.cast<IntegerAttr>().getInt();
-      }));
-}
-
 /// Conversion pattern that transforms a subview op into:
 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
 ///   2. Updates to the descriptor to introduce the data ptr, offset, size

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9a1ceebba97d5..109a1c60ddc39 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
@@ -116,24 +117,6 @@ static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
       }));
 }
 
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
-static void dispatchIndexOpFoldResult(OpFoldResult ofr,
-                                      SmallVectorImpl<Value> &dynamicVec,
-                                      SmallVectorImpl<int64_t> &staticVec,
-                                      int64_t sentinel) {
-  if (auto v = ofr.dyn_cast<Value>()) {
-    dynamicVec.push_back(v);
-    staticVec.push_back(sentinel);
-    return;
-  }
-  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
-  staticVec.push_back(apInt.getSExtValue());
-}
-
 /// This is a common class used for patterns of the form
 /// ```
 ///    someop(memrefcast(%src)) -> someop(%src)
@@ -819,14 +802,6 @@ LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
 // PadTensorOp
 //===----------------------------------------------------------------------===//
 
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
-        return a.cast<IntegerAttr>().getInt();
-      }));
-}
-
 static LogicalResult verify(PadTensorOp op) {
   auto sourceType = op.source().getType().cast<RankedTensorType>();
   auto resultType = op.result().getType().cast<RankedTensorType>();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index d02570af3622b..c951e70f18d83 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -110,6 +110,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 829b988dbad73..92382a6906835 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -814,8 +814,8 @@ struct GenericPadTensorOpVectorizationPattern
         readInBounds.push_back(false);
         // Write is out-of-bounds if low padding > 0.
         writeInBounds.push_back(
-            isEqualConstantIntOrValue(padOp.getMixedLowPad()[i],
-                                      rewriter.getIndexAttr(0)));
+            getConstantIntValue(padOp.getMixedLowPad()[i]) ==
+            static_cast<int64_t>(0));
       } else {
         // Neither source nor result dim of padOp is static. Cannot vectorize
         // the copy.
@@ -1098,9 +1098,9 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
     SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
     expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
     if (!llvm::all_of(
-            llvm::zip(insertOp.getMixedSizes(), expectedSizes),
-            [](auto it) { return isEqualConstantInt(std::get<0>(it),
-                                                    std::get<1>(it)); }))
+            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
+              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
+            }))
       return failure();
 
     // Generate TransferReadOp: Read entire source tensor and add high padding.

diff  --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 6ac47b11996a3..6f9aeaa19cb22 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRMemRef
 
   LINK_LIBS PUBLIC
   MLIRDialect
+  MLIRDialectUtils
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemRefUtils

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 8d003577eb533..cc4e7a49363a5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -32,40 +33,6 @@ Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
   return builder.create<mlir::ConstantOp>(loc, type, value);
 }
 
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
-        return a.cast<IntegerAttr>().getInt();
-      }));
-}
-
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
-static void dispatchIndexOpFoldResult(OpFoldResult ofr,
-                                      SmallVectorImpl<Value> &dynamicVec,
-                                      SmallVectorImpl<int64_t> &staticVec,
-                                      int64_t sentinel) {
-  if (auto v = ofr.dyn_cast<Value>()) {
-    dynamicVec.push_back(v);
-    staticVec.push_back(sentinel);
-    return;
-  }
-  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
-  staticVec.push_back(apInt.getSExtValue());
-}
-
-static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
-                                       SmallVectorImpl<Value> &dynamicVec,
-                                       SmallVectorImpl<int64_t> &staticVec,
-                                       int64_t sentinel) {
-  for (auto ofr : ofrs)
-    dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
-}
-
 //===----------------------------------------------------------------------===//
 // Common canonicalization pattern support logic
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 73c0c4a607b63..837986fc03535 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -33,38 +33,6 @@
 
 using namespace mlir;
 
-/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
-/// IntegerAttr, return the integer.
-llvm::Optional<int64_t> mlir::getConstantIntValue(OpFoldResult ofr) {
-  Attribute attr = ofr.dyn_cast<Attribute>();
-  // Note: isa+cast-like pattern allows writing the condition below as 1 line.
-  if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
-    attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
-  if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
-    return intAttr.getValue().getSExtValue();
-  return llvm::None;
-}
-
-/// Return true if ofr and value are the same integer.
-/// Ignore integer bitwidth and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType has no bitwidth.
-bool mlir::isEqualConstantInt(OpFoldResult ofr, int64_t value) {
-  auto ofrValue = getConstantIntValue(ofr);
-  return ofrValue && *ofrValue == value;
-}
-
-/// Return true if ofr1 and ofr2 are the same integer constant attribute values
-/// or the same SSA value.
-/// Ignore integer bitwidth and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType has no bitwidth.
-bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
-  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
-  if (cst1 && cst2 && *cst1 == *cst2)
-    return true;
-  auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
-  return v1 && v2 && v1 == v2;
-}
-
 //===----------------------------------------------------------------------===//
 // StandardOpsDialect Interfaces
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index e1fad1b358f00..4b6886ef244d0 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRTensor
 
   LINK_LIBS PUBLIC
   MLIRCastInterfaces
+  MLIRDialectUtils
   MLIRIR
   MLIRSideEffectInterfaces
   MLIRSupport

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8a4b212db0329..28a5f5df21cef 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Matchers.h"
@@ -516,32 +517,6 @@ static LogicalResult verify(ReshapeOp op) {
 // ExtractSliceOp
 //===----------------------------------------------------------------------===//
 
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
-static void dispatchIndexOpFoldResult(OpFoldResult ofr,
-                                      SmallVectorImpl<Value> &dynamicVec,
-                                      SmallVectorImpl<int64_t> &staticVec,
-                                      int64_t sentinel) {
-  if (auto v = ofr.dyn_cast<Value>()) {
-    dynamicVec.push_back(v);
-    staticVec.push_back(sentinel);
-    return;
-  }
-  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
-  staticVec.push_back(apInt.getSExtValue());
-}
-
-static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
-                                       SmallVectorImpl<Value> &dynamicVec,
-                                       SmallVectorImpl<int64_t> &staticVec,
-                                       int64_t sentinel) {
-  for (auto ofr : ofrs)
-    dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
-}
-
 /// An extract_slice op result type can be fully inferred from the source type
 /// and the static representation of offsets, sizes and strides. Special
 /// sentinels encode the dynamic case.
@@ -563,14 +538,6 @@ Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
                                sourceRankedTensorType.getElementType());
 }
 
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
-        return a.cast<IntegerAttr>().getInt();
-      }));
-}
-
 Type ExtractSliceOp::inferResultType(
     RankedTensorType sourceRankedTensorType,
     ArrayRef<OpFoldResult> leadingStaticOffsets,
@@ -890,17 +857,16 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                            ShapedType shapedType) {
   OpBuilder b(op.getContext());
   for (OpFoldResult ofr : op.getMixedOffsets())
-    if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(0)))
+    if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
       return failure();
   // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
   // is appropriate.
   auto shape = shapedType.getShape();
   for (auto it : llvm::zip(op.getMixedSizes(), shape))
-    if (!isEqualConstantIntOrValue(std::get<0>(it),
-                                   b.getIndexAttr(std::get<1>(it))))
+    if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
       return failure();
   for (OpFoldResult ofr : op.getMixedStrides())
-    if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(1)))
+    if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
       return failure();
   return success();
 }

diff  --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
index a640e3581b4a3..098b6b48b032f 100644
--- a/mlir/lib/Dialect/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(MLIRDialectUtils
   StructuredOpsUtils.cpp
+  StaticValueUtils.cpp
 
   LINK_LIBS PUBLIC
   MLIRIR

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
new file mode 100644
index 0000000000000..bf7d662dbfcc9
--- /dev/null
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -0,0 +1,79 @@
+//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/APSInt.h"
+
+namespace mlir {
+
+/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
+/// it is a Value or into `staticVec` if it is an IntegerAttr.
+/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// `staticVec`. This is useful to extract mixed static and dynamic entries that
+/// come from an AttrSizedOperandSegments trait.
+void dispatchIndexOpFoldResult(OpFoldResult ofr,
+                               SmallVectorImpl<Value> &dynamicVec,
+                               SmallVectorImpl<int64_t> &staticVec,
+                               int64_t sentinel) {
+  if (auto v = ofr.dyn_cast<Value>()) {
+    dynamicVec.push_back(v);
+    staticVec.push_back(sentinel);
+    return;
+  }
+  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
+  staticVec.push_back(apInt.getSExtValue());
+}
+
+void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
+                                SmallVectorImpl<Value> &dynamicVec,
+                                SmallVectorImpl<int64_t> &staticVec,
+                                int64_t sentinel) {
+  for (OpFoldResult ofr : ofrs)
+    dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
+}
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+  return llvm::to_vector<4>(
+      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+        return a.cast<IntegerAttr>().getInt();
+      }));
+}
+
+/// If ofr is a constant integer or an IntegerAttr, return the integer.
+Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
+  // Case 1: Check for Constant integer.
+  if (auto val = ofr.dyn_cast<Value>()) {
+    APSInt intVal;
+    if (matchPattern(val, m_ConstantInt(&intVal)))
+      return intVal.getSExtValue();
+    return llvm::None;
+  }
+  // Case 2: Check for IntegerAttr.
+  Attribute attr = ofr.dyn_cast<Attribute>();
+  if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+    return intAttr.getValue().getSExtValue();
+  return llvm::None;
+}
+
+/// Return true if ofr1 and ofr2 are the same integer constant attribute values
+/// or the same SSA value.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
+  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
+  if (cst1 && cst2 && *cst1 == *cst2)
+    return true;
+  auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
+  return v1 && v1 == v2;
+}
+
+} // namespace mlir
+


        


More information about the Mlir-commits mailing list