[Mlir-commits] [mlir] 9dd4c2d - [mlir][vector] Add constant folder for vector.shuffle ops
Lei Zhang
llvmlistbot at llvm.org
Fri Feb 4 13:59:38 PST 2022
Author: Lei Zhang
Date: 2022-02-04T16:59:32-05:00
New Revision: 9dd4c2dcb63a41a20746b74781bd5ece627c47a8
URL: https://github.com/llvm/llvm-project/commit/9dd4c2dcb63a41a20746b74781bd5ece627c47a8
DIFF: https://github.com/llvm/llvm-project/commit/9dd4c2dcb63a41a20746b74781bd5ece627c47a8.diff
LOG: [mlir][vector] Add constant folder for vector.shuffle ops
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D119032
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d2efe2c30962a..6feb8e0fa4a8a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -465,6 +465,7 @@ def Vector_ShuffleOp :
let builders = [
OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef<int64_t>")>
];
+ let hasFolder = 1;
let extraClassDeclaration = [{
static StringRef getMaskAttrName() { return "mask"; }
VectorType getV1VectorType() {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c2515f706122e..b576005c47e9b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1803,6 +1803,33 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
return success();
}
+/// Folds a vector.shuffle whose two operands are both constants into a single
+/// constant.
+///
+/// `operands` holds one attribute per operand; an entry is null when that
+/// operand has no known constant value. Returns a null OpFoldResult when no
+/// fold applies: an operand is not constant, an operand constant is not a
+/// DenseElementsAttr, or the first operand is not a 1-D vector.
+OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
+  // dyn_cast_or_null rejects both a missing constant and a constant stored in
+  // some other elements attribute kind; an unchecked cast would assert on the
+  // latter instead of simply declining to fold.
+  auto lhs = operands.front().dyn_cast_or_null<DenseElementsAttr>();
+  auto rhs = operands.back().dyn_cast_or_null<DenseElementsAttr>();
+  if (!lhs || !rhs)
+    return {};
+
+  auto lhsType = lhs.getType().cast<VectorType>();
+  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
+  // manipulation.
+  if (lhsType.getRank() != 1)
+    return {};
+  int64_t lhsSize = lhsType.getDimSize(0);
+
+  ArrayAttr maskAttr = this->mask();
+  SmallVector<Attribute> results;
+  results.reserve(maskAttr.size());
+  auto lhsElements = lhs.getValues<Attribute>();
+  auto rhsElements = rhs.getValues<Attribute>();
+  // Each mask entry indexes into the concatenation of the two operands:
+  // values below lhsSize pick from the first operand, the remaining values
+  // pick from the second operand after subtracting lhsSize.
+  for (const auto &index : maskAttr.getAsValueRange<IntegerAttr>()) {
+    int64_t i = index.getZExtValue();
+    if (i < lhsSize)
+      results.push_back(lhsElements[i]);
+    else
+      results.push_back(rhsElements[i - lhsSize]);
+  }
+
+  return DenseElementsAttr::get(getVectorType(), results);
+}
+
//===----------------------------------------------------------------------===//
// InsertElementOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a25b0687ca756..522f8dea8b470 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1254,3 +1254,15 @@ func @splat_fold() -> vector<4xf32> {
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
// CHECK-NEXT: return [[V]] : vector<4xf32>
}
+
+// -----
+
+// Folding of a 1-D vector.shuffle with two constant operands: mask entries
+// 0-2 select from %v0 and entries 3-5 select from %v1, so the mask
+// [3, 2, 5, 1] is expected to produce the single constant [3, 2, 5, 1] with
+// no vector.shuffle left in the output.
+// CHECK-LABEL: func @shuffle_1d
+// CHECK: %[[V:.+]] = arith.constant dense<[3, 2, 5, 1]> : vector<4xi32>
+// CHECK: return %[[V]]
+func @shuffle_1d() -> vector<4xi32> {
+ %v0 = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+ %v1 = arith.constant dense<[3, 4, 5]> : vector<3xi32>
+ %shuffle = vector.shuffle %v0, %v1 [3, 2, 5, 1] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<4xi32>
+}
More information about the Mlir-commits
mailing list