[Mlir-commits] [mlir] Preserve Encoding During TensorOp Creation (PR #80871)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 6 08:43:00 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: ian Bearman (manbearian)

<details>
<summary>Changes</summary>

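Preserve the encoding attribute when deriving new RankedTensorTypes from existing ones in the tensor dialect (collapse_shape, extract_slice, the insert_slice source-cast canonicalization, pad, and pack). For collapse_shape, an encoding that implements VerifiableTensorEncoding is dropped if it fails verification against the collapsed shape.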

---
Full diff: https://github.com/llvm/llvm-project/pull/80871.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+28-10) 
- (modified) mlir/test/Dialect/Linalg/collapse-dim.mlir (+7-7) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b21e89ae3a5713..8f117d9464f5f4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/TensorEncoding.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
@@ -1622,7 +1623,20 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type,
     currentDim += dim;
   }
 
-  return RankedTensorType::get(newShape, type.getElementType());
+  auto encoding = type.getEncoding();
+  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
+    auto ignoreError = [&] {
+      auto emitter = mlir::emitError(UnknownLoc::get(type.getContext()));
+      emitter.abandon();
+      return emitter;
+    };
+    if (failed(
+            v.verifyEncoding(newShape, type.getElementType(), ignoreError))) {
+      // strip the encoding if it is not valid for the new shape.
+      encoding = Attribute();
+    }
+  }
+  return RankedTensorType::get(newShape, type.getElementType(), encoding);
 }
 
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
@@ -1902,7 +1916,8 @@ RankedTensorType ExtractSliceOp::inferResultType(
   assert(static_cast<int64_t>(staticSizes.size()) ==
              sourceTensorType.getRank() &&
          "unexpected staticSizes not equal to rank of source");
-  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
+  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
+                               sourceTensorType.getEncoding());
 }
 
 RankedTensorType ExtractSliceOp::inferResultType(
@@ -1943,7 +1958,8 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
       if (!dimsToProject.test(pos))
         projectedShape.push_back(shape[pos]);
     inferredType =
-        RankedTensorType::get(projectedShape, inferredType.getElementType());
+        RankedTensorType::get(projectedShape, inferredType.getElementType(),
+                              inferredType.getEncoding());
   }
   return inferredType;
 }
@@ -2663,8 +2679,8 @@ struct InsertSliceOpSourceCastInserter final
     if (!hasValidSizesOffsets(newSrcShape))
       return failure();
 
-    RankedTensorType newSrcType =
-        RankedTensorType::get(newSrcShape, srcType.getElementType());
+    RankedTensorType newSrcType = RankedTensorType::get(
+        newSrcShape, srcType.getElementType(), srcType.getEncoding());
     if (srcType == newSrcType ||
         !preservesStaticInformation(srcType, newSrcType) ||
         !tensor::CastOp::areCastCompatible(srcType, newSrcType))
@@ -2815,7 +2831,8 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
     }
   }
 
-  return RankedTensorType::get(inferredShape, sourceType.getElementType());
+  return RankedTensorType::get(inferredShape, sourceType.getElementType(),
+                               sourceType.getEncoding());
 }
 
 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -3601,9 +3618,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
         "tiling factors must equal the number of dimensions to tile");
   }
 
-  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
-                              ? packOrUnPack.getDestType()
-                              : packOrUnPack.getSourceType();
+  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+                                    ? packOrUnPack.getDestType()
+                                    : packOrUnPack.getSourceType();
   size_t packedRank = packedType.getRank();
   // Require output rank to match input rank + number of blocking factors.
   if (unpackedRank + mixedTiles.size() != packedRank) {
@@ -3870,7 +3887,8 @@ RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                          ArrayRef<int64_t> outerDimsPerm) {
   SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
-  return RankedTensorType::get(resultShape, sourceType.getElementType());
+  return RankedTensorType::get(resultShape, sourceType.getElementType(),
+                               sourceType.getEncoding());
 }
 
 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 547320f5338747..dc3b202c8ea9c4 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -122,13 +122,13 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
 // CHECK-LABEL:   func.func @linalg_copy(
 // CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
-// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
-// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64>
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64>
+// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:           return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:         }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/80871


More information about the Mlir-commits mailing list