[Mlir-commits] [mlir] [MLIR][Vector] Add Lowering for vector.step (PR #113655)
Manupa Karunaratne
llvmlistbot at llvm.org
Thu Oct 31 03:56:30 PDT 2024
https://github.com/manupak updated https://github.com/llvm/llvm-project/pull/113655
>From c7b9fd20ec40617243f2f8187ec1e4f635d7e18d Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne at amd.com>
Date: Fri, 25 Oct 2024 05:45:39 +0000
Subject: [PATCH] [MLIR][Vector] Add Lowering for vector.step
Currently, the lowering for vector.step is
implemented as an op folder. This is not ideal if
we want to transform the op and defer the
materialization of the constants until much later.
This commit adds a rewrite pattern + transform
op to do this instead, thus enabling more control
over the lowering.
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 1 -
.../Vector/Transforms/LoweringPatterns.h | 7 +++
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 10 +++-
.../TransformOps/LinalgTransformOps.cpp | 1 +
.../Transforms/SparseVectorization.cpp | 2 +
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ------
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 +
.../Vector/Transforms/LowerVectorStep.cpp | 49 +++++++++++++++++++
.../VectorToLLVM/vector-to-llvm.mlir | 10 ++++
.../Linalg/vectorization-scalable.mlir | 21 ++++++--
mlir/test/Dialect/Vector/canonicalize.mlir | 9 ----
11 files changed, 95 insertions(+), 30 deletions(-)
create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 474f4ccf4891de..b54a8b7fe8680d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2946,7 +2946,6 @@ def Vector_StepOp : Vector_Op<"step", [Pure]> {
%1 = vector.step : vector<[4]xindex> ; [0, 1, .., <vscale * 4 - 1>]
```
}];
- let hasFolder = 1;
let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 1976b8399c7f9c..3d643c96b45008 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -235,6 +235,13 @@ void populateVectorTransferPermutationMapLoweringPatterns(
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [StepToArithConstantOp]
+/// Convert vector.step op into arith ops if not using scalable vectors
+void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [FlattenGather]
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 984af50a7b0a51..58ca84c8d7bca6 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1865,12 +1865,17 @@ struct VectorFromElementsLowering
};
/// Conversion pattern for vector.step.
-struct VectorStepOpLowering : public ConvertOpToLLVMPattern<vector::StepOp> {
+struct VectorScalableStepOpLowering
+ : public ConvertOpToLLVMPattern<vector::StepOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto resultType = cast<VectorType>(stepOp.getType());
+ if (!resultType.isScalable()) {
+ return failure();
+ }
Type llvmType = typeConverter->convertType(stepOp.getType());
rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
return success();
@@ -1886,6 +1891,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
populateVectorInsertExtractStridedSliceTransforms(patterns);
+ populateVectorStepLoweringPatterns(patterns);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
@@ -1903,7 +1909,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
- VectorStepOpLowering>(converter);
+ VectorScalableStepOpLowering>(converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3d3f0a93a3829b..469c85a982c51a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3484,6 +3484,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
if (getVectorizePadding())
linalg::populatePadOpVectorizationPatterns(patterns);
+ vector::populateVectorStepLoweringPatterns(patterns);
TrackingListener listener(state, *this);
GreedyRewriteConfig config;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index d1c95dabd88a5e..b2eca539194a87 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -27,6 +27,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Matchers.h"
using namespace mlir;
@@ -664,6 +665,7 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
bool enableVLAVectorization,
bool enableSIMDIndex32) {
assert(vectorLength > 0);
+ vector::populateVectorStepLoweringPatterns(patterns);
patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
enableVLAVectorization, enableSIMDIndex32);
patterns.add<ReducChainRewriter<vector::InsertElementOp>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d71a236f62f454..daabf0deebe4e4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6423,20 +6423,6 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
return SplatElementsAttr::get(getType(), {constOperand});
}
-//===----------------------------------------------------------------------===//
-// StepOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult StepOp::fold(FoldAdaptor adaptor) {
- auto resultType = cast<VectorType>(getType());
- if (resultType.isScalable())
- return nullptr;
- SmallVector<APInt> indices;
- for (unsigned i = 0; i < resultType.getNumElements(); i++)
- indices.push_back(APInt(/*width=*/64, i));
- return DenseElementsAttr::get(resultType, indices);
-}
-
//===----------------------------------------------------------------------===//
// WarpExecuteOnLane0Op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index b7e8724c3c2582..9a3bd5d4593d63 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorMultiReduction.cpp
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
+ LowerVectorStep.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
SubsetOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
new file mode 100644
index 00000000000000..ee5568aefda27b
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
@@ -0,0 +1,49 @@
+//===- LowerVectorStep.cpp - Lower 'vector.step' operation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.step' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "vector-step-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::StepOp stepOp,
+ PatternRewriter &rewriter) const override {
+ auto resultType = cast<VectorType>(stepOp.getType());
+ if (resultType.isScalable()) {
+ return failure();
+ }
+ int64_t elementCount = resultType.getNumElements();
+ SmallVector<APInt> indices =
+ llvm::map_to_vector(llvm::seq(elementCount),
+ [](int64_t i) { return APInt(/*width=*/64, i); });
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ stepOp, DenseElementsAttr::get(resultType, indices));
+ return success();
+ }
+};
+} // namespace
+
+void mlir::vector::populateVectorStepLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<StepToArithConstantOpRewrite>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index eb6da71b063273..c1de24fd0403ce 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -3448,3 +3448,13 @@ func.func @vector_step_scalable() -> vector<[4]xindex> {
%0 = vector.step : vector<[4]xindex>
return %0 : vector<[4]xindex>
}
+
+// -----
+
+// CHECK-LABEL: @vector_step
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: return %[[CST]] : vector<4xindex>
+func.func @vector_step() -> vector<4xindex> {
+ %0 = vector.step : vector<4xindex>
+ return %0 : vector<4xindex>
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index c3a30e3ee209e8..ca11bc328927b0 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -167,10 +167,23 @@ func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?x
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
-// CHECK: %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
-// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%c0, %c0, %2], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
-// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
-// CHECK: return %[[OUT]] : tensor<1x1x?xf32>
+
+// CHECK-DAG: %[[STEP1:.+]] = vector.step : vector<1xindex>
+// CHECK-DAG: %[[STEP1B:.+]] = vector.broadcast %[[STEP1]] : vector<1xindex> to vector<1x1x[4]xindex>
+// CHECK-DAG: %[[STEP1B_CAST:.+]] = vector.shape_cast %[[STEP1B]] : vector<1x1x[4]xindex> to vector<[4]xindex>
+// CHECK-DAG: %[[STEP1B_ELEMENT:.+]] = vector.extractelement %[[STEP1B_CAST]][%c0_i32 : i32] : vector<[4]xindex>
+
+// CHECK-DAG: %[[STEP2:.+]] = vector.step : vector<1xindex>
+// CHECK-DAG: %[[STEP2B:.+]] = vector.broadcast %[[STEP2]] : vector<1xindex> to vector<1x1x[4]xindex>
+// CHECK-DAG: %[[STEP2B_CAST:.+]] = vector.shape_cast %[[STEP2B]] : vector<1x1x[4]xindex> to vector<[4]xindex>
+// CHECK-DAG: %[[STEP2B_ELEMENT:.+]] = vector.extractelement %[[STEP2B_CAST]][%c0_i32 : i32] : vector<[4]xindex>
+
+// CHECK-DAG: %[[STEP_SCALABLE:.+]] = vector.step : vector<[4]xindex>
+// CHECK-DAG: %[[STEP_SCALABLE_ELEMENT:.+]] = vector.extractelement %[[STEP_SCALABLE]][%c0_i32 : i32] : vector<[4]xindex>
+
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%[[STEP1B_ELEMENT]], %[[STEP2B_ELEMENT]], %[[STEP_SCALABLE_ELEMENT]]], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
+// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
+// CHECK: return %[[OUT]] : tensor<1x1x?xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6d6bc199e601c0..3f079c486e5ca4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2722,15 +2722,6 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}
-// -----
-
-// CHECK-LABEL: @fold_vector_step_to_constant
-// CHECK: %[[CONSTANT:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK: return %[[CONSTANT]] : vector<4xindex>
-func.func @fold_vector_step_to_constant() -> vector<4xindex> {
- %0 = vector.step : vector<4xindex>
- return %0 : vector<4xindex>
-}
// -----
More information about the Mlir-commits
mailing list