[Mlir-commits] [mlir] 51afca6 - [mlir][vector] Simplify fold pattern for ExtractOp(constant). NFC.
Jakub Kuderski
llvmlistbot at llvm.org
Wed Nov 23 15:58:27 PST 2022
Author: Jakub Kuderski
Date: 2022-11-23T18:57:21-05:00
New Revision: 51afca640c2968c51a92dfb89e67e10bdcb98216
URL: https://github.com/llvm/llvm-project/commit/51afca640c2968c51a92dfb89e67e10bdcb98216
DIFF: https://github.com/llvm/llvm-project/commit/51afca640c2968c51a92dfb89e67e10bdcb98216.diff
LOG: [mlir][vector] Simplify fold pattern for ExtractOp(constant). NFC.
Use helper functions. Reuse array element attributes.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D138609
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bd96ee7de24f7..22d7bdc3542e4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1623,7 +1623,6 @@ class ExtractOpScalarVectorConstantFolder final
return failure();
auto vecTy = sourceVector.getType().cast<VectorType>();
- Type elemTy = vecTy.getElementType();
ArrayAttr positions = extractOp.getPosition();
if (vecTy.isScalable())
return failure();
@@ -1631,36 +1630,17 @@ class ExtractOpScalarVectorConstantFolder final
// constants.
if (vecTy.getRank() != static_cast<int64_t>(positions.size()))
return failure();
- // TODO: Handle more element types, e.g., complex values.
- if (!elemTy.isIntOrIndexOrFloat())
- return failure();
// The splat case is handled by `ExtractOpSplatConstantFolder`.
auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
if (!dense || dense.isSplat())
return failure();
- // Calculate the flattened position.
- int64_t elemPosition = 0;
- int64_t innerElems = 1;
- for (auto [dimSize, positionInDim] :
- llvm::reverse(llvm::zip(vecTy.getShape(), positions))) {
- int64_t positionVal = positionInDim.cast<IntegerAttr>().getInt();
- elemPosition += positionVal * innerElems;
- innerElems *= dimSize;
- }
-
- Attribute newAttr;
- if (vecTy.getElementType().isIntOrIndex()) {
- auto values = to_vector(dense.getValues<APInt>());
- newAttr = IntegerAttr::get(extractOp.getType(), values[elemPosition]);
- } else if (vecTy.getElementType().isa<FloatType>()) {
- auto values = to_vector(dense.getValues<APFloat>());
- newAttr = FloatAttr::get(extractOp.getType(), values[elemPosition]);
- }
- assert(newAttr && "Unhandled case");
-
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
+ // Calculate the linearized position.
+ int64_t elemPosition =
+ linearize(getI64SubArray(positions), computeStrides(vecTy.getShape()));
+ Attribute elementValue = *(dense.value_begin<Attribute>() + elemPosition);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, elementValue);
return success();
}
};
More information about the Mlir-commits mailing list