[Mlir-commits] [mlir] 937e40a - [mlir] Remove the non-templated DenseElementsAttr::getSplatValue
River Riddle
llvmlistbot at llvm.org
Mon Nov 8 17:56:51 PST 2021
Author: River Riddle
Date: 2021-11-09T01:40:40Z
New Revision: 937e40a8cf14ae2bb0545f23c3a32383f68d343a
URL: https://github.com/llvm/llvm-project/commit/937e40a8cf14ae2bb0545f23c3a32383f68d343a
DIFF: https://github.com/llvm/llvm-project/commit/937e40a8cf14ae2bb0545f23c3a32383f68d343a.diff
LOG: [mlir] Remove the non-templated DenseElementsAttr::getSplatValue
This overload predates the templated variant and has simply been forwarding
to getSplatValue<Attribute>() for some time. Removing it makes the API more
uniform and also helps prevent users from assuming that it is "cheap".
Added:
Modified:
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/Matchers.h
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 37da2eb9150b2..ba0fbe49239fc 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -353,7 +353,6 @@ class DenseElementsAttr : public Attribute {
/// Return the splat value for this attribute. This asserts that the attribute
/// corresponds to a splat.
- Attribute getSplatValue() const { return getSplatValue<Attribute>(); }
template <typename T>
typename std::enable_if<!std::is_base_of<Attribute, T>::value ||
std::is_same<Attribute, T>::value,
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 548cbe3208c95..1cac3eaa5914c 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -110,7 +110,7 @@ struct constant_int_op_binder {
if (type.isa<VectorType, RankedTensorType>()) {
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
return attr_value_binder<IntegerAttr>(bind_value)
- .match(splatAttr.getSplatValue());
+ .match(splatAttr.getSplatValue<Attribute>());
}
}
return false;
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index b4ea696f80d0c..521b3fcab0c6f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -451,7 +451,7 @@ struct GlobalMemrefOpLowering
// For scalar memrefs, the global variable created is of the element type,
// so unpack the elements attribute to extract the value.
if (type.getRank() == 0)
- initialValue = elementsAttr.getValues<Attribute>()[0];
+ initialValue = elementsAttr.getSplatValue<Attribute>();
}
uint64_t alignment = global.alignment().getValueOr(0);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index b97a04638a653..a9f3c7da9c842 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -349,7 +349,8 @@ static void convertConstantOp(arith::ConstantOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
assert(constantSupportsMMAMatrixType(op));
OpBuilder b(op);
- Attribute splat = op.getValue().cast<SplatElementsAttr>().getSplatValue();
+ Attribute splat =
+ op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
auto scalarConstant =
b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
const char *fragType = inferFragType(op);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 8831a8d70f743..566b3a2c5b306 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1574,7 +1574,7 @@ static bool isZeroAttribute(Attribute value) {
if (auto fpValue = value.dyn_cast<FloatAttr>())
return fpValue.getValue().isZero();
if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
- return isZeroAttribute(splatValue.getSplatValue());
+ return isZeroAttribute(splatValue.getSplatValue<Attribute>());
if (auto elementsValue = value.dyn_cast<ElementsAttr>())
return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
if (auto arrayValue = value.dyn_cast<ArrayAttr>())
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index bda4ee7899c2b..703bc9c6d5b6f 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1395,7 +1395,7 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
if (operands[0].getType().isIntOrIndexOrFloat())
return DenseElementsAttr::get(vectorType, operands[0]);
if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
- return DenseElementsAttr::get(vectorType, attr.getSplatValue());
+ return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
return {};
}
@@ -2212,7 +2212,7 @@ class StridedSliceConstantFolder final
if (!dense)
return failure();
auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
- dense.getSplatValue());
+ dense.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
newAttr);
return success();
@@ -3670,8 +3670,9 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
if (!dense)
return failure();
- auto newAttr = DenseElementsAttr::get(
- shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
+ auto newAttr =
+ DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(),
+ dense.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
return success();
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 39b83c2255654..db0e7c2d84780 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -139,7 +139,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
if (denseElementsAttr.isSplat() &&
(type.isa<VectorType>() || hasVectorElementType)) {
llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
- innermostLLVMType, denseElementsAttr.getSplatValue(), loc,
+ innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
moduleTranslation, /*isTopLevel=*/false);
llvm::Constant *splatVector =
llvm::ConstantDataVector::getSplat(0, splatValue);
@@ -254,8 +254,9 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
isa<llvm::ArrayType, llvm::VectorType>(elementType);
llvm::Constant *child = getLLVMConstant(
elementType,
- elementTypeSequential ? splatAttr : splatAttr.getSplatValue(), loc,
- moduleTranslation, false);
+ elementTypeSequential ? splatAttr
+ : splatAttr.getSplatValue<Attribute>(),
+ loc, moduleTranslation, false);
if (!child)
return nullptr;
if (llvmType->isVectorTy())
More information about the Mlir-commits mailing list