[Mlir-commits] [mlir] [Linalg] Fix SoftmaxOp's reify result shape calculation (PR #67790)

Abhishek Varma llvmlistbot at llvm.org
Fri Sep 29 04:29:19 PDT 2023


https://github.com/Abhishek-Varma created https://github.com/llvm/llvm-project/pull/67790

-- SoftmaxOp's `reifyResultShapes` wrongly cast the operation to `LinalgOp`,
   which is invalid because SoftmaxOp is not a `LinalgOp`.
-- This commit fixes SoftmaxOp's result-shape reification by computing the
   shape directly from the op's operands.

Signed-off-by: Abhishek Varma <abhishek at nod-labs.com>

>From 3bc996d2a5cdea27b2f78afd871d2b4b9d994cb8 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <avarma094 at gmail.com>
Date: Fri, 29 Sep 2023 11:22:24 +0000
Subject: [PATCH] [Linalg] Fix SoftmaxOp's reify result shape calculation

-- SoftmaxOp's `reifyResultShapes` wrongly cast the operation to
   `LinalgOp`, which is invalid because SoftmaxOp is not a `LinalgOp`.
-- This commit fixes SoftmaxOp's result-shape reification by computing
   the shape directly from the op's operands.

Signed-off-by: Abhishek Varma <abhishek at nod-labs.com>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5871c59e1d35d95..7ccf9d39deea105 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2387,8 +2387,27 @@ LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
 LogicalResult
 SoftmaxOp::reifyResultShapes(OpBuilder &b,
                              ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  return cast<LinalgOp>(getOperation())
-      .reifyResultShapes(b, reifiedReturnShapes);
+  // With buffer (memref) semantics the op has no results, so there is
+  // nothing to reify.
+  if (getNumResults() == 0)
+    return success();
+
+  // The result type is the type of the `output` (init) operand, so reify the
+  // shape from that operand: static sizes come from its type and dynamic
+  // sizes are materialized as a `dim` op on it. Taking them from the input
+  // could yield a Value for a dimension that is static in the result type.
+  Location loc = getLoc();
+  ShapedType outputType = getOutputOperandType();
+  SmallVector<OpFoldResult> shapes;
+  shapes.reserve(outputType.getRank());
+  for (int64_t dim : llvm::seq<int64_t>(0, outputType.getRank())) {
+    if (outputType.isDynamicDim(dim))
+      shapes.push_back(createOrFoldDimOp(b, loc, getOutput(), dim));
+    else
+      shapes.push_back(b.getIndexAttr(outputType.getDimSize(dim)));
+  }
+  reifiedReturnShapes.emplace_back(std::move(shapes));
+  return success();
 }
 
 void SoftmaxOp::getEffects(



More information about the Mlir-commits mailing list