[Mlir-commits] [mlir] [MLIR][Vector] Add Lowering for vector.step (PR #113655)

Manupa Karunaratne llvmlistbot at llvm.org
Tue Oct 29 05:07:22 PDT 2024


https://github.com/manupak updated https://github.com/llvm/llvm-project/pull/113655

>From 2a9ce906d442703eac436170968b2597a505930d Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne at amd.com>
Date: Fri, 25 Oct 2024 05:45:39 +0000
Subject: [PATCH 1/4] [MLIR][Vector] Add Lowering for vector.step

Currently, the lowering of vector.step lives
in the op's folder, which eagerly turns it into
a constant. This is not ideal if we want to
transform it and defer the materialization of
the constants until much later.

This commit adds a rewrite pattern + transform
op to do this instead, enabling more control
over the lowering.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  1 -
 .../Vector/TransformOps/VectorTransformOps.td |  9 +++
 .../Vector/Transforms/LoweringPatterns.h      |  7 ++
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  1 +
 .../Transforms/SparseVectorization.cpp        |  2 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 14 ----
 .../TransformOps/VectorTransformOps.cpp       |  5 ++
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 +
 .../Vector/Transforms/LowerVectorStep.cpp     | 64 +++++++++++++++
 .../Linalg/vectorization-scalable.mlir        |  1 +
 .../Linalg/vectorization-with-patterns.mlir   | 15 ++++
 .../Linalg/vectorize-tensor-extract.mlir      | 79 +++++++++++++++++--
 mlir/test/Dialect/Vector/canonicalize.mlir    |  9 ---
 13 files changed, 178 insertions(+), 30 deletions(-)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c02b16ea931706..5e7b6659548203 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2940,7 +2940,6 @@ def Vector_StepOp : Vector_Op<"step", [Pure]> {
     %1 = vector.step : vector<[4]xindex> ; [0, 1, .., <vscale * 4 - 1>]
     ```
   }];
-  let hasFolder = 1;
   let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
   let assemblyFormat = "attr-dict `:` type($result)";
 }
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index c973eca0132a92..3262aa37a81877 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -453,4 +453,13 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyLowerStepToArithOps : Op<Transform_Dialect,
+    "apply_patterns.vector.step_to_arith",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Convert vector.step to arith if not using scalable vectors.
+  }];
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 1976b8399c7f9c..27581443814322 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -235,6 +235,13 @@ void populateVectorTransferPermutationMapLoweringPatterns(
 void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
                                         PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [StepToArithOps]
+/// Convert vector.step op into arith ops if not scalable
+void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
+                                        PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [FlattenGather]
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 984af50a7b0a51..15a545dbb42180 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1886,6 +1886,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
   populateVectorInsertExtractStridedSliceTransforms(patterns);
+  populateVectorStepLoweringPatterns(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
   patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index d1c95dabd88a5e..b2eca539194a87 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Matchers.h"
 
 using namespace mlir;
@@ -664,6 +665,7 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                                bool enableVLAVectorization,
                                                bool enableSIMDIndex32) {
   assert(vectorLength > 0);
+  vector::populateVectorStepLoweringPatterns(patterns);
   patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                               enableVLAVectorization, enableSIMDIndex32);
   patterns.add<ReducChainRewriter<vector::InsertElementOp>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..2daf3e8a29ff9c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6423,20 +6423,6 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   return SplatElementsAttr::get(getType(), {constOperand});
 }
 
-//===----------------------------------------------------------------------===//
-// StepOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult StepOp::fold(FoldAdaptor adaptor) {
-  auto resultType = cast<VectorType>(getType());
-  if (resultType.isScalable())
-    return nullptr;
-  SmallVector<APInt> indices;
-  for (unsigned i = 0; i < resultType.getNumElements(); i++)
-    indices.push_back(APInt(/*width=*/64, i));
-  return DenseElementsAttr::get(resultType, indices);
-}
-
 //===----------------------------------------------------------------------===//
 // WarpExecuteOnLane0Op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 241e83e234d621..b6f49b85c6205a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -207,6 +207,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
 }
 
+void transform::ApplyLowerStepToArithOps::populatePatterns(
+    RewritePatternSet &patterns) {
+  populateVectorStepLoweringPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index b7e8724c3c2582..9a3bd5d4593d63 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorMultiReduction.cpp
   LowerVectorScan.cpp
   LowerVectorShapeCast.cpp
+  LowerVectorStep.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
   SubsetOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
new file mode 100644
index 00000000000000..fb7a516e4b41c3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
@@ -0,0 +1,64 @@
+//===- LowerVectorStep.cpp - Lower 'vector.step' operation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.step' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+
+#define DEBUG_TYPE "vector-step-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+struct StepToArithOps : public OpRewritePattern<vector::StepOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    auto resultType = cast<VectorType>(stepOp.getType());
+    if (!resultType.isScalable()) {
+      SmallVector<APInt> indices;
+      for (unsigned i = 0; i < resultType.getNumElements(); i++)
+        indices.push_back(APInt(/*width=*/64, i));
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          stepOp, DenseElementsAttr::get(resultType, indices));
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorStepLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<StepToArithOps>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index c3a30e3ee209e8..96866885df4d49 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -180,6 +180,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
       transform.apply_patterns.canonicalization
       transform.apply_patterns.linalg.tiling_canonicalization
     } : !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 1c6a786bfa436d..8a3dbe6765ebd8 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -346,6 +346,11 @@ module attributes {transform.with_named_sequence} {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+     %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+     transform.apply_patterns to %func {
+       transform.apply_patterns.vector.step_to_arith
+     } : !transform.any_op
      transform.yield
   }
 }
@@ -474,6 +479,11 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+       transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
@@ -505,6 +515,11 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e611a8e22ee23f..2cf7804264ef8c 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -32,6 +32,11 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+       transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
@@ -172,6 +177,11 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+       transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
@@ -207,8 +217,9 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 79 : index
-// CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
-// CHECK:           %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : index to vector<1x4xindex>
+// CHECK-DAG:       %[[VAL_1_BCAST:.*]] = vector.broadcast %[[VAL_1]] : index to vector<1x4xindex>
+// CHECK-DAG:       %[[VAL_2_BCAST:.*]] = vector.broadcast %[[VAL_2]] : index to vector<1x4xindex>
+// CHECK:           %[[VAL_12:.*]] = arith.addi %[[VAL_1_BCAST]], %[[VAL_2_BCAST]] : vector<1x4xindex>
 // CHECK:           %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
 // CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
 // CHECK:           %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
@@ -226,6 +237,11 @@ module attributes {transform.with_named_sequence} {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+     %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+     transform.apply_patterns to %func {
+       transform.apply_patterns.vector.step_to_arith
+     } : !transform.any_op
      transform.yield
    }
 }
@@ -306,6 +322,11 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg0
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
@@ -321,8 +342,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
 // CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
-// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
-// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
+// CHECK: %[[ARG1_BCAST0:.*]] = vector.broadcast %arg1 : index to vector<1xindex>
+// CHECK: %[[ARG1_BCAST1:.*]] = vector.broadcast %arg1 : index to vector<1xindex>
+// CHECK: %[[B2:.*]] = arith.addi %[[ARG1_BCAST0]], %[[ARG1_BCAST1]] : vector<1xindex>
 // CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
 // CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
 // CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
@@ -357,17 +379,22 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg2
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
 
 // CHECK-LABEL:   func.func @index_from_output_column_vector_gather_load(
 // CHECK-SAME:      %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
+// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK:           %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
 // CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK:           %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
 // CHECK:           %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
 // CHECK:           %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
@@ -404,16 +431,21 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg2
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
 
 // CHECK-LABEL:   func.func @index_from_output_column_vector_contiguous_load(
 // CHECK-SAME:      %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
+// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
 // CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK:           %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
 // CHECK:           %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
 // CHECK:           %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
@@ -464,6 +496,11 @@ module attributes {transform.with_named_sequence} {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+     %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+     transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+     } : !transform.any_op
      transform.yield
    }
 }
@@ -508,6 +545,11 @@ module attributes {transform.with_named_sequence} {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+     %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+     transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+     } : !transform.any_op
      transform.yield
    }
 }
@@ -599,6 +641,11 @@ module attributes {transform.with_named_sequence} {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+     %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+     transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+     } : !transform.any_op
      transform.yield
    }
 }
@@ -641,6 +688,11 @@ module attributes {transform.with_named_sequence} {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+     %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+     transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+     } : !transform.any_op
      transform.yield
    }
 }
@@ -683,6 +735,11 @@ module attributes {transform.with_named_sequence} {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+     %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+     transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+     } : !transform.any_op
      transform.yield
    }
 }
@@ -724,6 +781,11 @@ module attributes {transform.with_named_sequence} {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+     %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+     transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+     } : !transform.any_op
      transform.yield
    }
 }
@@ -798,6 +860,11 @@ module attributes {transform.with_named_sequence} {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+     %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+     transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+     } : !transform.any_op
      transform.yield
    }
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6d6bc199e601c0..3f079c486e5ca4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2722,15 +2722,6 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
   return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
 }
 
-// -----
-
-// CHECK-LABEL: @fold_vector_step_to_constant
-// CHECK: %[[CONSTANT:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK: return %[[CONSTANT]] : vector<4xindex>
-func.func @fold_vector_step_to_constant() -> vector<4xindex> {
-  %0 = vector.step : vector<4xindex>
-  return %0 : vector<4xindex>
-}
 
 // -----
 

>From 136b6582cbcc641dcd4f59040505ac82977f3731 Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne at amd.com>
Date: Mon, 28 Oct 2024 14:54:21 +0000
Subject: [PATCH 2/4] * remove unnecessary headers * cleanup step rewrite
 pattern * rename step rewriter pattern + transform op

---
 .../Vector/TransformOps/VectorTransformOps.td |  4 +-
 .../Vector/Transforms/LoweringPatterns.h      |  2 +-
 .../TransformOps/VectorTransformOps.cpp       |  2 +-
 .../Vector/Transforms/LowerVectorStep.cpp     | 37 ++++++-------------
 .../Linalg/vectorization-scalable.mlir        |  2 +-
 .../Linalg/vectorization-with-patterns.mlir   |  6 +--
 .../Linalg/vectorize-tensor-extract.mlir      | 26 ++++++-------
 7 files changed, 32 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 3262aa37a81877..79fe8401359a5b 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -453,8 +453,8 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
-def ApplyLowerStepToArithOps : Op<Transform_Dialect,
-    "apply_patterns.vector.step_to_arith",
+def ApplyLowerStepToArithConstantOp : Op<Transform_Dialect,
+    "apply_patterns.vector.step_to_arith_constant",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
     Convert vector.step to arith if not using scalable vectors.
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 27581443814322..80e4186664a640 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -238,7 +238,7 @@ void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
 /// Populate the pattern set with the following patterns:
 ///
 /// [StepToArithOps]
-/// Convert vector.step op into arith ops if not scalable
+/// Convert vector.step op into arith ops if not using scalable vectors
 void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
                                         PatternBenefit benefit = 1);
 
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index b6f49b85c6205a..dd981a9f0ad4bd 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -207,7 +207,7 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
 }
 
-void transform::ApplyLowerStepToArithOps::populatePatterns(
+void transform::ApplyLowerStepToArithConstantOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorStepLoweringPatterns(patterns);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
index fb7a516e4b41c3..f0601e07756b92 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
@@ -11,26 +11,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
 
 #define DEBUG_TYPE "vector-step-lowering"
 
@@ -39,26 +23,27 @@ using namespace mlir::vector;
 
 namespace {
 
-struct StepToArithOps : public OpRewritePattern<vector::StepOp> {
+struct StepToArithConstantOp final : OpRewritePattern<vector::StepOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::StepOp stepOp,
                                 PatternRewriter &rewriter) const override {
     auto resultType = cast<VectorType>(stepOp.getType());
-    if (!resultType.isScalable()) {
-      SmallVector<APInt> indices;
-      for (unsigned i = 0; i < resultType.getNumElements(); i++)
-        indices.push_back(APInt(/*width=*/64, i));
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          stepOp, DenseElementsAttr::get(resultType, indices));
-      return success();
+    if (resultType.isScalable()) {
+      return failure();
     }
-    return failure();
+    int64_t elementCount = resultType.getNumElements();
+    SmallVector<APInt> indices =
+        llvm::map_to_vector(llvm::seq(elementCount),
+                            [](int64_t i) { return APInt(/*width=*/64, i); });
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+        stepOp, DenseElementsAttr::get(resultType, indices));
+    return success();
   }
 };
 } // namespace
 
 void mlir::vector::populateVectorStepLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<StepToArithOps>(patterns.getContext(), benefit);
+  patterns.add<StepToArithConstantOp>(patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index 96866885df4d49..6709cffbbce6de 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -180,7 +180,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.vector.step_to_arith
+      transform.apply_patterns.vector.step_to_arith_constant
       transform.apply_patterns.canonicalization
       transform.apply_patterns.linalg.tiling_canonicalization
     } : !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 8a3dbe6765ebd8..a6f822cbced744 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -349,7 +349,7 @@ module attributes {transform.with_named_sequence} {
      %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %func {
-       transform.apply_patterns.vector.step_to_arith
+       transform.apply_patterns.vector.step_to_arith_constant
      } : !transform.any_op
      transform.yield
   }
@@ -482,7 +482,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-       transform.apply_patterns.vector.step_to_arith
+       transform.apply_patterns.vector.step_to_arith_constant
     } : !transform.any_op
     transform.yield
   }
@@ -518,7 +518,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.vector.step_to_arith
+      transform.apply_patterns.vector.step_to_arith_constant
     } : !transform.any_op
     transform.yield
   }
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 2cf7804264ef8c..b8942cc4d3bff5 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -35,7 +35,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-       transform.apply_patterns.vector.step_to_arith
+       transform.apply_patterns.vector.step_to_arith_constant
     } : !transform.any_op
     transform.yield
   }
@@ -180,7 +180,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-       transform.apply_patterns.vector.step_to_arith
+       transform.apply_patterns.vector.step_to_arith_constant
     } : !transform.any_op
     transform.yield
   }
@@ -240,7 +240,7 @@ module attributes {transform.with_named_sequence} {
      %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %func {
-       transform.apply_patterns.vector.step_to_arith
+       transform.apply_patterns.vector.step_to_arith_constant
      } : !transform.any_op
      transform.yield
    }
@@ -325,7 +325,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.vector.step_to_arith
+      transform.apply_patterns.vector.step_to_arith_constant
     } : !transform.any_op
     transform.yield
   }
@@ -382,7 +382,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %arg2
        : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.vector.step_to_arith
+      transform.apply_patterns.vector.step_to_arith_constant
     } : !transform.any_op
     transform.yield
   }
@@ -434,7 +434,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %arg2
        : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.vector.step_to_arith
+      transform.apply_patterns.vector.step_to_arith_constant
     } : !transform.any_op
     transform.yield
   }
@@ -499,7 +499,7 @@ module attributes {transform.with_named_sequence} {
      %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %func {
-      transform.apply_patterns.vector.step_to_arith
+      transform.apply_patterns.vector.step_to_arith_constant
      } : !transform.any_op
      transform.yield
    }
@@ -548,7 +548,7 @@ module attributes {transform.with_named_sequence} {
      %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %func {
-      transform.apply_patterns.vector.step_to_arith
+      transform.apply_patterns.vector.step_to_arith_constant
      } : !transform.any_op
      transform.yield
    }
@@ -644,7 +644,7 @@ module attributes {transform.with_named_sequence} {
      %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %func {
-      transform.apply_patterns.vector.step_to_arith
+      transform.apply_patterns.vector.step_to_arith_constant
      } : !transform.any_op
      transform.yield
    }
@@ -691,7 +691,7 @@ module attributes {transform.with_named_sequence} {
      %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %func {
-      transform.apply_patterns.vector.step_to_arith
+      transform.apply_patterns.vector.step_to_arith_constant
      } : !transform.any_op
      transform.yield
    }
@@ -738,7 +738,7 @@ module attributes {transform.with_named_sequence} {
      %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %func {
-      transform.apply_patterns.vector.step_to_arith
+      transform.apply_patterns.vector.step_to_arith_constant
      } : !transform.any_op
      transform.yield
    }
@@ -784,7 +784,7 @@ module attributes {transform.with_named_sequence} {
      %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %func {
-      transform.apply_patterns.vector.step_to_arith
+      transform.apply_patterns.vector.step_to_arith_constant
      } : !transform.any_op
      transform.yield
    }
@@ -863,7 +863,7 @@ module attributes {transform.with_named_sequence} {
      %func = transform.structured.match ops{["func.func"]} in %arg1
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %func {
-      transform.apply_patterns.vector.step_to_arith
+      transform.apply_patterns.vector.step_to_arith_constant
      } : !transform.any_op
      transform.yield
    }

>From b75c42834e6f9db5e98ac3779ab943d5abf7e2ab Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne at amd.com>
Date: Mon, 28 Oct 2024 15:20:00 +0000
Subject: [PATCH 3/4] * Update pattern name in populateVectorStepLoweringPatterns doc comment

---
 mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 80e4186664a640..3d643c96b45008 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -237,7 +237,7 @@ void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
 
 /// Populate the pattern set with the following patterns:
 ///
-/// [StepToArithOps]
+/// [StepToArithConstantOp]
 /// Convert vector.step op into arith ops if not using scalable vectors
 void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
                                         PatternBenefit benefit = 1);

>From 3da2b2618ad8280d5ef39e587cd99182f886a449 Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne at amd.com>
Date: Tue, 29 Oct 2024 12:06:45 +0000
Subject: [PATCH 4/4] * Rename rewrite pattern to StepToArithConstantOpRewrite

---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
index f0601e07756b92..ee5568aefda27b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
@@ -23,7 +23,7 @@ using namespace mlir::vector;
 
 namespace {
 
-struct StepToArithConstantOp final : OpRewritePattern<vector::StepOp> {
+struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::StepOp stepOp,
@@ -45,5 +45,5 @@ struct StepToArithConstantOp final : OpRewritePattern<vector::StepOp> {
 
 void mlir::vector::populateVectorStepLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<StepToArithConstantOp>(patterns.getContext(), benefit);
+  patterns.add<StepToArithConstantOpRewrite>(patterns.getContext(), benefit);
 }



More information about the Mlir-commits mailing list