[Mlir-commits] [mlir] 070d211 - [mlir][Linalg] Fix SoftmaxOp's reify result shape calculation (#67790)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 29 10:55:40 PDT 2023


Author: Abhishek Varma
Date: 2023-09-29T10:55:35-07:00
New Revision: 070d2114b15e7eefba55516e4689028b6079415f

URL: https://github.com/llvm/llvm-project/commit/070d2114b15e7eefba55516e4689028b6079415f
DIFF: https://github.com/llvm/llvm-project/commit/070d2114b15e7eefba55516e4689028b6079415f.diff

LOG: [mlir][Linalg] Fix SoftmaxOp's reify result shape calculation (#67790)

-- SoftmaxOp's `reifyResultShapes` function was wrongly casting the op to a
`LinalgOp`.
-- This commit fixes SoftmaxOp's reified result shape calculation by
computing each result dimension directly instead.

Signed-off-by: Abhishek Varma <abhishek at nod-labs.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5871c59e1d35d95..491f4a66574616e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2387,8 +2387,23 @@ LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
 LogicalResult
 SoftmaxOp::reifyResultShapes(OpBuilder &b,
                              ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  return cast<LinalgOp>(getOperation())
-      .reifyResultShapes(b, reifiedReturnShapes);
+  SmallVector<OpFoldResult> shapes;
+  Location loc = getOperation()->getLoc();
+  IRRewriter rewriter(b);
+  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
+  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
+  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
+    if (!outputShapedType.isDynamicDim(dim)) {
+      // Static dim: Return IntegerAttr.
+      shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
+    } else {
+      // Dynamic dim: Return Value.
+      OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
+      shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
+    }
+  }
+  reifiedReturnShapes.emplace_back(std::move(shapes));
+  return success();
 }
 
 void SoftmaxOp::getEffects(

diff  --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index aeb357d4ee86d7a..4262cd23e7469d6 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -276,3 +276,22 @@ func.func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index
 //      CHECK:   %[[IN_DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
 //      CHECK:   %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
 //      CHECK:   return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
+
+// -----
+
+func.func @dim_of_softmax_op(%arg0: tensor<?x16x?xf32>, %arg1: tensor<2x?x?xf32>) -> (index, index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = linalg.softmax dimension(2) ins(%arg0 : tensor<?x16x?xf32>) outs(%arg1 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+  %dim = tensor.dim %0, %c0 : tensor<2x?x?xf32>
+  %dim_0 = tensor.dim %0, %c1 : tensor<2x?x?xf32>
+  %dim_1 = tensor.dim %0, %c2 : tensor<2x?x?xf32>
+  return %dim, %dim_0, %dim_1 : index, index, index
+}
+// CHECK-LABEL: @dim_of_softmax_op
+// CHECK-SAME:  (%[[INPUT:.*]]: tensor<?x16x?xf32>
+// CHECK-NEXT:      %[[C2:.*]] = arith.constant 2 : index
+// CHECK-NEXT:      %[[C16:.*]] = arith.constant 16 : index
+// CHECK-NEXT:      %[[IN_DIM2:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<?x16x?xf32>
+// CHECK-NEXT:      return %[[C2]], %[[C16]], %[[IN_DIM2]] : index, index, index


        


More information about the Mlir-commits mailing list