[Mlir-commits] [mlir] port fixes from local llvm (PR #78484)

ian Bearman llvmlistbot at llvm.org
Thu Jan 18 13:09:37 PST 2024


https://github.com/manbearian updated https://github.com/llvm/llvm-project/pull/78484

>From b83074fa260d2ce4b876b71d507224cb9476a944 Mon Sep 17 00:00:00 2001
From: Ian Bearman <ianb at microsoft.com>
Date: Wed, 17 Jan 2024 17:54:25 +0000
Subject: [PATCH] port fixes from local llvm

---
 .../IR/BufferizableOpInterface.h              | 13 +++++++
 .../BufferizableOpInterfaceImpl.cpp           | 13 ++++---
 .../IR/BufferizableOpInterface.cpp            | 12 +++---
 .../Bufferization/IR/BufferizationOps.cpp     |  4 +-
 .../FuncBufferizableOpInterfaceImpl.cpp       |  4 +-
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 38 ++++++++++++++-----
 .../BufferizableOpInterfaceImpl.cpp           |  8 ++--
 mlir/test/Dialect/Linalg/collapse-dim.mlir    | 14 +++----
 8 files changed, 70 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 63e2d19e68ef97..478cdab8298754 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -257,6 +257,9 @@ struct BufferizationOptions {
   /// Parameters: Value, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
       Value, Attribute memorySpace, const BufferizationOptions &)>;
+  // Produce a MemorySpace attribute from a tensor type
+  using GetMemorySpaceFn =
+      std::function<std::optional<Attribute>(TensorType t)>;
 
   BufferizationOptions();
 
@@ -351,6 +354,16 @@ struct BufferizationOptions {
   /// used.
   UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
 
+  // Used during type conversion to determine the memory space for a memref
+  // based on the original tensor type
+  GetMemorySpaceFn getMemorySpaceFn = nullptr;
+
+  std::optional<Attribute> getMemorySpace(TensorType t) const {
+    if (getMemorySpaceFn)
+      return getMemorySpaceFn(t);
+    return defaultMemorySpace;
+  }
+
   /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
   /// Should be used only with `testAnalysisOnly = true`.
   unsigned analysisFuzzerSeed = 0;
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f69b2557eec922..337ac0c0761440 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,17 +26,18 @@ struct ConstantOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
+    auto type = constantOp.getType().dyn_cast<RankedTensorType>();
+
+    // Only ranked tensors are supported.
+    if (!type)
+      return failure();
 
     Attribute memorySpace;
-    if (options.defaultMemorySpace.has_value())
-      memorySpace = *options.defaultMemorySpace;
+    if (options.getMemorySpace(type))
+      memorySpace = *options.getMemorySpace(type);
     else
       return constantOp->emitError("could not infer memory space");
 
-    // Only ranked tensors are supported.
-    if (!isa<RankedTensorType>(constantOp.getType()))
-      return failure();
-
     // Only constants inside a module are supported.
     auto moduleOp = constantOp->getParentOfType<ModuleOp>();
     if (!moduleOp)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 4b1dfee4a2b926..1a849155abed02 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -682,11 +682,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
     return bufferizableOp.getBufferType(value, options, invocationStack);
 
   // Op is not bufferizable.
-  if (!options.defaultMemorySpace.has_value())
+  auto memSpace = options.getMemorySpace(value.getType().cast<TensorType>());
+  if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{},
-                       *options.defaultMemorySpace);
+  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
 }
 
 bool bufferization::hasTensorSemantics(Operation *op) {
@@ -943,11 +943,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
 
   // If we do not know the memory space and there is no default memory space,
   // report a failure.
-  if (!options.defaultMemorySpace.has_value())
+  auto memSpace = options.getMemorySpace(value.getType().cast<TensorType>());
+  if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{},
-                       *options.defaultMemorySpace);
+  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 253fcf2525121b..8618436ff99382 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -234,8 +234,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     if (failed(copyBufferType))
       return failure();
     memorySpace = copyBufferType->getMemorySpace();
-  } else if (options.defaultMemorySpace.has_value()) {
-    memorySpace = *options.defaultMemorySpace;
+  } else if (auto x = options.getMemorySpace(getType()); x.has_value()) {
+    memorySpace = *x;
   } else {
     return getOperation()->emitError("could not infer memory space");
   }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 07cd1f90b17df4..4a06bac31961b1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -66,7 +66,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
   assert(tensorType && "expected TensorType");
 
   BaseMemRefType memrefType = options.functionArgTypeConverterFn(
-      tensorType, *options.defaultMemorySpace, funcOp, options);
+      tensorType, *options.getMemorySpace(tensorType), funcOp, options);
 
   auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
       index, BufferizationDialect::kBufferLayoutAttrName);
@@ -443,7 +443,7 @@ struct FuncOpInterface
       // Note: If `inferFunctionResultLayout = true`, cast are later folded
       // away.
       BaseMemRefType resultType = options.functionArgTypeConverterFn(
-          tensorType, *options.defaultMemorySpace, funcOp, options);
+          tensorType, *options.getMemorySpace(tensorType), funcOp, options);
       Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
           loc, resultType, returnVal);
       returnValues.push_back(toMemrefOp);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b2fe58099b2fb3..5eb7f6ef24721c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/TensorEncoding.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
@@ -1622,7 +1623,20 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type,
     currentDim += dim;
   }
 
-  return RankedTensorType::get(newShape, type.getElementType());
+  auto encoding = type.getEncoding();
+  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
+    auto ignoreError = [&] {
+      auto emitter = mlir::emitError(UnknownLoc::get(type.getContext()));
+      emitter.abandon();
+      return emitter;
+    };
+    if (failed(
+            v.verifyEncoding(newShape, type.getElementType(), ignoreError))) {
+      // strip the encoding if it is not valid for the new shape.
+      encoding = Attribute();
+    }
+  }
+  return RankedTensorType::get(newShape, type.getElementType(), encoding);
 }
 
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
@@ -1902,7 +1916,8 @@ RankedTensorType ExtractSliceOp::inferResultType(
   assert(static_cast<int64_t>(staticSizes.size()) ==
              sourceTensorType.getRank() &&
          "unexpected staticSizes not equal to rank of source");
-  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
+  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
+                               sourceTensorType.getEncoding());
 }
 
 RankedTensorType ExtractSliceOp::inferResultType(
@@ -1943,7 +1958,8 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
       if (!dimsToProject.test(pos))
         projectedShape.push_back(shape[pos]);
     inferredType =
-        RankedTensorType::get(projectedShape, inferredType.getElementType());
+        RankedTensorType::get(projectedShape, inferredType.getElementType(),
+                              inferredType.getEncoding());
   }
   return inferredType;
 }
@@ -2663,8 +2679,8 @@ struct InsertSliceOpSourceCastInserter final
     if (!hasValidSizesOffsets(newSrcShape))
       return failure();
 
-    RankedTensorType newSrcType =
-        RankedTensorType::get(newSrcShape, srcType.getElementType());
+    RankedTensorType newSrcType = RankedTensorType::get(
+        newSrcShape, srcType.getElementType(), srcType.getEncoding());
     if (srcType == newSrcType ||
         !preservesStaticInformation(srcType, newSrcType) ||
         !tensor::CastOp::areCastCompatible(srcType, newSrcType))
@@ -2815,7 +2831,8 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
     }
   }
 
-  return RankedTensorType::get(inferredShape, sourceType.getElementType());
+  return RankedTensorType::get(inferredShape, sourceType.getElementType(),
+                               sourceType.getEncoding());
 }
 
 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -3597,9 +3614,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
         "tiling factors must equal the number of dimensions to tile");
   }
 
-  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
-                              ? packOrUnPack.getDestType()
-                              : packOrUnPack.getSourceType();
+  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+                                    ? packOrUnPack.getDestType()
+                                    : packOrUnPack.getSourceType();
   size_t packedRank = packedType.getRank();
   // Require output rank to match input rank + number of blocking factors.
   if (unpackedRank + mixedTiles.size() != packedRank) {
@@ -3866,7 +3883,8 @@ RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                          ArrayRef<int64_t> outerDimsPerm) {
   SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
-  return RankedTensorType::get(resultShape, sourceType.getElementType());
+  return RankedTensorType::get(resultShape, sourceType.getElementType(),
+                               sourceType.getEncoding());
 }
 
 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 2cd57e7324b4dc..907f6bf23b0141 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -473,14 +473,14 @@ struct FromElementsOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
+    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
 
     // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpace != Attribute())
+    if (options.getMemorySpace(tensorType) != Attribute())
       return op->emitError("memory space not implemented yet");
 
     // Allocate a buffer for the result.
     Location loc = op->getLoc();
-    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
     auto shape = tensorType.getShape();
     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
@@ -588,8 +588,10 @@ struct GenerateOpInterface
                           const BufferizationOptions &options) const {
     auto generateOp = cast<tensor::GenerateOp>(op);
 
+    auto type = generateOp.getResult().getType();
+
     // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpace != Attribute())
+    if (options.getMemorySpace(type) != Attribute())
       return op->emitError("memory space not implemented yet");
 
     // Allocate memory.
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 547320f5338747..dc3b202c8ea9c4 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -122,13 +122,13 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
 // CHECK-LABEL:   func.func @linalg_copy(
 // CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
-// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
-// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64>
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64>
+// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:           return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:         }
 



More information about the Mlir-commits mailing list