[Mlir-commits] [mlir] 51afca6 - [mlir][vector] Simplify fold pattern for ExtractOp(constant). NFC.

Jakub Kuderski llvmlistbot at llvm.org
Wed Nov 23 15:58:27 PST 2022


Author: Jakub Kuderski
Date: 2022-11-23T18:57:21-05:00
New Revision: 51afca640c2968c51a92dfb89e67e10bdcb98216

URL: https://github.com/llvm/llvm-project/commit/51afca640c2968c51a92dfb89e67e10bdcb98216
DIFF: https://github.com/llvm/llvm-project/commit/51afca640c2968c51a92dfb89e67e10bdcb98216.diff

LOG: [mlir][vector] Simplify fold pattern for ExtractOp(constant). NFC.

Use helper functions. Reuse array element attributes.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D138609

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bd96ee7de24f7..22d7bdc3542e4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1623,7 +1623,6 @@ class ExtractOpScalarVectorConstantFolder final
       return failure();
 
     auto vecTy = sourceVector.getType().cast<VectorType>();
-    Type elemTy = vecTy.getElementType();
     ArrayAttr positions = extractOp.getPosition();
     if (vecTy.isScalable())
       return failure();
@@ -1631,36 +1630,17 @@ class ExtractOpScalarVectorConstantFolder final
     // constants.
     if (vecTy.getRank() != static_cast<int64_t>(positions.size()))
       return failure();
-    // TODO: Handle more element types, e.g., complex values.
-    if (!elemTy.isIntOrIndexOrFloat())
-      return failure();
 
     // The splat case is handled by `ExtractOpSplatConstantFolder`.
     auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
     if (!dense || dense.isSplat())
       return failure();
 
-    // Calculate the flattened position.
-    int64_t elemPosition = 0;
-    int64_t innerElems = 1;
-    for (auto [dimSize, positionInDim] :
-         llvm::reverse(llvm::zip(vecTy.getShape(), positions))) {
-      int64_t positionVal = positionInDim.cast<IntegerAttr>().getInt();
-      elemPosition += positionVal * innerElems;
-      innerElems *= dimSize;
-    }
-
-    Attribute newAttr;
-    if (vecTy.getElementType().isIntOrIndex()) {
-      auto values = to_vector(dense.getValues<APInt>());
-      newAttr = IntegerAttr::get(extractOp.getType(), values[elemPosition]);
-    } else if (vecTy.getElementType().isa<FloatType>()) {
-      auto values = to_vector(dense.getValues<APFloat>());
-      newAttr = FloatAttr::get(extractOp.getType(), values[elemPosition]);
-    }
-    assert(newAttr && "Unhandled case");
-
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
+    // Calculate the linearized position.
+    int64_t elemPosition =
+        linearize(getI64SubArray(positions), computeStrides(vecTy.getShape()));
+    Attribute elementValue = *(dense.value_begin<Attribute>() + elemPosition);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, elementValue);
     return success();
   }
 };


        


More information about the Mlir-commits mailing list