[Mlir-commits] [mlir] b1f2e26 - [mlir][tosa] Switch TosaFoldConstantTranspose to use ElementsAttr.

Jacques Pienaar llvmlistbot at llvm.org
Mon Aug 22 15:45:33 PDT 2022


Author: Jacques Pienaar
Date: 2022-08-22T15:45:23-07:00
New Revision: b1f2e2664ef059be834806c37eab32248a9f396f

URL: https://github.com/llvm/llvm-project/commit/b1f2e2664ef059be834806c37eab32248a9f396f
DIFF: https://github.com/llvm/llvm-project/commit/b1f2e2664ef059be834806c37eab32248a9f396f.diff

LOG: [mlir][tosa] Switch TosaFoldConstantTranspose to use ElementsAttr.

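Also avoid redoing the index calculation: iterate over the input values directly instead of re-indexing them with the recomputed source indices.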

Differential Revision: https://reviews.llvm.org/D132274

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
index f86605da54ef2..3147073819235 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
@@ -30,7 +30,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
     if (!outputType.getElementType().isIntOrIndexOrFloat())
       return failure();
 
-    DenseElementsAttr inputValues;
+    ElementsAttr inputValues;
     if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
       return failure();
     // Make sure the input is a constant that has a single user.
@@ -57,10 +57,9 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
     // index.
     auto attrValues = inputValues.getValues<Attribute>();
     ArrayRef<int64_t> outputShape = outputType.getShape();
-    for (int srcLinearIndex = 0; srcLinearIndex < numElements;
-         ++srcLinearIndex) {
+    for (const auto &it : llvm::enumerate(attrValues)) {
       SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
-      int totalCount = srcLinearIndex;
+      int totalCount = it.index();
       for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
         srcIndices[dim] = totalCount % inputShape[dim];
         totalCount /= inputShape[dim];
@@ -74,7 +73,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
       for (int dim = 1; dim < outputType.getRank(); ++dim)
         dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
 
-      outputValues[dstLinearIndex] = attrValues[srcIndices];
+      outputValues[dstLinearIndex] = it.value();
     }
 
     rewriter.replaceOpWithNewOp<tosa::ConstOp>(


        


More information about the Mlir-commits mailing list