[Mlir-commits] [mlir] 9dd4c2d - [mlir][vector] Add constant folder for vector.shuffle ops

Lei Zhang llvmlistbot at llvm.org
Fri Feb 4 13:59:38 PST 2022


Author: Lei Zhang
Date: 2022-02-04T16:59:32-05:00
New Revision: 9dd4c2dcb63a41a20746b74781bd5ece627c47a8

URL: https://github.com/llvm/llvm-project/commit/9dd4c2dcb63a41a20746b74781bd5ece627c47a8
DIFF: https://github.com/llvm/llvm-project/commit/9dd4c2dcb63a41a20746b74781bd5ece627c47a8.diff

LOG: [mlir][vector] Add constant folder for vector.shuffle ops

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D119032

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d2efe2c30962a..6feb8e0fa4a8a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -465,6 +465,7 @@ def Vector_ShuffleOp :
   let builders = [
     OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef<int64_t>")>
   ];
+  let hasFolder = 1;
   let extraClassDeclaration = [{
     static StringRef getMaskAttrName() { return "mask"; }
     VectorType getV1VectorType() {

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c2515f706122e..b576005c47e9b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1803,6 +1803,33 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
+OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
+  // Fold only when both operands are dense constants; unknown operands and
+  // other ElementsAttr kinds (e.g. sparse) are left alone instead of crashing.
+  auto lhs = operands.front().dyn_cast_or_null<DenseElementsAttr>();
+  auto rhs = operands.back().dyn_cast_or_null<DenseElementsAttr>();
+  if (!lhs || !rhs)
+    return {};
+
+  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
+  // manipulation.
+  auto lhsType = lhs.getType().cast<VectorType>();
+  if (lhsType.getRank() != 1)
+    return {};
+  int64_t lhsSize = lhsType.getDimSize(0);
+
+  // Mask entries in [0, lhsSize) select from lhs; the rest index into rhs.
+  auto lhsElements = lhs.getValues<Attribute>();
+  auto rhsElements = rhs.getValues<Attribute>();
+  SmallVector<Attribute> results;
+  for (const auto &index : this->mask().getAsValueRange<IntegerAttr>()) {
+    int64_t i = index.getZExtValue();
+    results.push_back(i >= lhsSize ? rhsElements[i - lhsSize]
+                                   : lhsElements[i]);
+  }
+  return DenseElementsAttr::get(getVectorType(), results);
+}
+
 //===----------------------------------------------------------------------===//
 // InsertElementOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a25b0687ca756..522f8dea8b470 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1254,3 +1254,15 @@ func @splat_fold() -> vector<4xf32> {
   // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
   // CHECK-NEXT: return [[V]] : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @shuffle_1d
+//       CHECK:   %[[V:.+]] = arith.constant dense<[3, 2, 5, 1]> : vector<4xi32>
+//       CHECK:   return %[[V]]
+func @shuffle_1d() -> vector<4xi32> {
+  %lhs = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+  %rhs = arith.constant dense<[3, 4, 5]> : vector<3xi32>
+  %res = vector.shuffle %lhs, %rhs [3, 2, 5, 1] : vector<3xi32>, vector<3xi32>
+  return %res : vector<4xi32>
+}


        


More information about the Mlir-commits mailing list