[Mlir-commits] [mlir] 0a9f6b8 - [mlir][tensor/linalg] Fix bug in reifyResultShapes

Matthias Springer llvmlistbot at llvm.org
Fri Mar 10 02:38:08 PST 2023


Author: Matthias Springer
Date: 2023-03-10T11:37:54+01:00
New Revision: 0a9f6b8ca3fa9449cc0accba1bc11e98d6dbc6b6

URL: https://github.com/llvm/llvm-project/commit/0a9f6b8ca3fa9449cc0accba1bc11e98d6dbc6b6
DIFF: https://github.com/llvm/llvm-project/commit/0a9f6b8ca3fa9449cc0accba1bc11e98d6dbc6b6.diff

LOG: [mlir][tensor/linalg] Fix bug in reifyResultShapes

`reifyResultShapes` should return an IntegerAttr if and only if the corresponding dimension is static; for dynamic dimensions it must return a Value.

Differential Revision: https://reviews.llvm.org/D145702

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 6844b685f6fe2..22a3c04ff3745 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -644,10 +644,18 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
   for (OpOperand *opOperand : getDpsInitOperands()) {
     SmallVector<OpFoldResult> shapes;
     for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
-      if (checkDimExpr.visit(shapeExprs[pos]))
-        shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim));
-      else
-        shapes.push_back(allResultDimValues[pos]);
+      auto shapedType = opOperand->get().getType().cast<ShapedType>();
+      if (!shapedType.isDynamicDim(dim)) {
+        // Static dim: Return IntegerAttr.
+        shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
+      } else {
+        // Dynamic dim: Return Value.
+        OpFoldResult ofr =
+            checkDimExpr.visit(shapeExprs[pos])
+                ? createOrFoldDimOp(b, loc, opOperand->get(), dim)
+                : allResultDimValues[pos];
+        shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
+      }
       pos++;
     }
     reifiedReturnShapes.emplace_back(std::move(shapes));

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index baf213006a6dc..e1bf889088a9e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2205,8 +2205,13 @@ LogicalResult InsertSliceOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
   for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
-    reifiedReturnShapes[0][dim] =
-        builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
+    if (getType().isDynamicDim(dim)) {
+      reifiedReturnShapes[0][dim] =
+          builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
+    } else {
+      reifiedReturnShapes[0][dim] =
+          builder.getIndexAttr(getType().getDimSize(dim));
+    }
   }
   return success();
 }
@@ -3154,9 +3159,15 @@ reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                 "applies to only pack or unpack operations");
   int64_t destRank = op.getDestRank();
   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
+  ShapedType resultType = op.getResult().getType().template cast<ShapedType>();
   for (auto dim : llvm::seq<int64_t>(0, destRank)) {
-    reifiedReturnShapes[0][dim] =
-        builder.createOrFold<tensor::DimOp>(op.getLoc(), op.getDest(), dim);
+    if (resultType.isDynamicDim(dim)) {
+      reifiedReturnShapes[0][dim] =
+          builder.createOrFold<tensor::DimOp>(op.getLoc(), op.getDest(), dim);
+    } else {
+      reifiedReturnShapes[0][dim] =
+          builder.getIndexAttr(resultType.getDimSize(dim));
+    }
   }
   return success();
 }


        


More information about the Mlir-commits mailing list