[Mlir-commits] [mlir] ba95bf7 - [mlir][tensor] Add getMixedSizes helper

Matthias Springer llvmlistbot at llvm.org
Thu Aug 25 01:29:32 PDT 2022


Author: Matthias Springer
Date: 2022-08-25T10:25:41+02:00
New Revision: ba95bf765d081f64284e90eba852a4fe5469ff48

URL: https://github.com/llvm/llvm-project/commit/ba95bf765d081f64284e90eba852a4fe5469ff48
DIFF: https://github.com/llvm/llvm-project/commit/ba95bf765d081f64284e90eba852a4fe5469ff48.diff

LOG: [mlir][tensor] Add getMixedSizes helper

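This helper function returns the dimensions of a ranked tensor value as
OpFoldResults: static dimensions become index attributes, and dynamic
dimensions become the results of newly created tensor.dim ops. The duplicated
size-computation loops in createCanonicalRankReducingExtractSliceOp,
createCanonicalRankReducingInsertSliceOp, and the tensor.pad bufferization are
replaced with calls to it.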

Differential Revision: https://reviews.llvm.org/D132475

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 43d8ba1394012..0622c0ed32197 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -110,6 +110,10 @@ bool canFoldIntoProducerOp(CastOp castOp);
 /// that can be folded.
 LogicalResult foldTensorCast(Operation *op);
 
+/// Return the dimensions of the given tensor value.
+SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
+                                        Value value);
+
 /// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
 /// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor`
 /// to that of `targetType`.

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2d91f45205e6c..060e2fd9207aa 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -36,6 +36,21 @@ Operation *TensorDialect::materializeConstant(OpBuilder &builder,
   return nullptr;
 }
 
+SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
+                                                Location loc, Value value) {
+  auto tensorType = value.getType().cast<RankedTensorType>();
+  SmallVector<OpFoldResult> result;
+  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
+    if (tensorType.isDynamicDim(i)) {
+      Value size = builder.create<tensor::DimOp>(loc, value, i);
+      result.push_back(size);
+    } else {
+      result.push_back(builder.getIndexAttr(tensorType.getDimSize(i)));
+    }
+  }
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
@@ -1465,18 +1480,8 @@ Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
     OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
   auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
   unsigned rank = rankedTensorType.getRank();
-  auto shape = rankedTensorType.getShape();
   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
-  SmallVector<OpFoldResult> sizes;
-  for (unsigned i = 0, e = rank; i < e; ++i) {
-    OpFoldResult dim;
-    if (rankedTensorType.isDynamicDim(i))
-      dim = b.createOrFold<tensor::DimOp>(
-          loc, tensor, b.create<arith::ConstantIndexOp>(loc, i));
-    else
-      dim = b.getIndexAttr(shape[i]);
-    sizes.push_back(dim);
-  }
+  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
   return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                 offsets, sizes, strides);
@@ -1818,18 +1823,8 @@ Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
                                                              Value dest) {
   auto rankedTensorType = dest.getType().cast<RankedTensorType>();
   unsigned rank = rankedTensorType.getRank();
-  auto shape = rankedTensorType.getShape();
   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
-  SmallVector<OpFoldResult> sizes;
-  for (unsigned i = 0, e = rank; i < e; ++i) {
-    OpFoldResult dim;
-    if (rankedTensorType.isDynamicDim(i))
-      dim = b.createOrFold<tensor::DimOp>(
-          loc, dest, b.create<arith::ConstantIndexOp>(loc, i));
-    else
-      dim = b.getIndexAttr(shape[i]);
-    sizes.push_back(dim);
-  }
+  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
   return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                                sizes, strides);

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 3600524ce7e22..cf9e5c89dfaa5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -812,16 +812,10 @@ struct PadOpInterface
                                 generateOp.getBody().begin());
 
     // Create tensor::InsertSliceOp.
-    SmallVector<OpFoldResult> sliceSizes, sliceStrides;
-    for (int64_t i = 0; i < resultType.getRank(); ++i) {
-      sliceStrides.push_back(rewriter.getIndexAttr(1));
-      if (srcType.isDynamicDim(i)) {
-        Value size = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
-        sliceSizes.push_back(size);
-      } else {
-        sliceSizes.push_back(rewriter.getIndexAttr(srcType.getDimSize(i)));
-      }
-    }
+    SmallVector<OpFoldResult> sliceSizes =
+        getMixedSizes(rewriter, loc, padOp.getSource());
+    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
+                                           rewriter.getIndexAttr(1));
     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
         padOp, padOp.getSource(), generateOp.getResult(),
         /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);


        


More information about the Mlir-commits mailing list