[Mlir-commits] [mlir] 296d5cb - [mlir][BuiltinTypes] Return VectorType from VectorType::Builder conversion operator
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Aug 30 06:47:40 PDT 2023
Author: Benjamin Maxwell
Date: 2023-08-30T13:47:06Z
New Revision: 296d5cb60c20fe314babcd93fb5df5ecc24ae987
URL: https://github.com/llvm/llvm-project/commit/296d5cb60c20fe314babcd93fb5df5ecc24ae987
DIFF: https://github.com/llvm/llvm-project/commit/296d5cb60c20fe314babcd93fb5df5ecc24ae987.diff
LOG: [mlir][BuiltinTypes] Return VectorType from VectorType::Builder conversion operator
0-D vectors are now supported, so the special case of returning just the
element type can now be removed.
A few callers that relied on the old behaviour have been updated.
Reviewed By: awarzynski, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D159122
Added:
Modified:
mlir/include/mlir/IR/BuiltinTypes.h
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index ce68fc2673dcaf..f0b19fe543a5bf 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -357,12 +357,7 @@ class VectorType::Builder {
return *this;
}
- /// In the particular case where the vector has a single dimension that we
- /// drop, return the scalar element type.
- // TODO: unify once we have a VectorType that supports 0-D.
- operator Type() {
- if (shape.empty())
- return elementType;
+ operator VectorType() {
return VectorType::get(shape, elementType, scalableDims);
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 88bda3931a5a11..af539d2c3795a0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2216,7 +2216,7 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
return failure();
if (mask.size() != 1)
return failure();
- Type resType = VectorType::Builder(v1VectorType).setShape({1});
+ VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
shuffleOp.getV1());
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 66ac5ffef3e3ed..1b3d617a79edb7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -89,21 +89,20 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
PatternRewriter &rewriter) {
if (index == -1)
return val;
- Type lowType = VectorType::Builder(type).dropDim(0);
+ Type lowType = type.getRank() > 1 ? VectorType::Builder(type).dropDim(0)
+ : type.getElementType();
// At extraction dimension?
if (index == 0)
return rewriter.create<vector::ExtractOp>(loc, lowType, val, pos);
// Unroll leading dimensions.
VectorType vType = cast<VectorType>(lowType);
- Type resType = VectorType::Builder(type).dropDim(index);
- auto resVectorType = cast<VectorType>(resType);
+ VectorType resType = VectorType::Builder(type).dropDim(index);
Value result = rewriter.create<arith::ConstantOp>(
- loc, resVectorType, rewriter.getZeroAttr(resVectorType));
- for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
+ loc, resType, rewriter.getZeroAttr(resType));
+ for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, d);
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
- result =
- rewriter.create<vector::InsertOp>(loc, resVectorType, load, result, d);
+ result = rewriter.create<vector::InsertOp>(loc, resType, load, result, d);
}
return result;
}
@@ -120,13 +119,13 @@ static Value reshapeStore(Location loc, Value val, Value result,
if (index == 0)
return rewriter.create<vector::InsertOp>(loc, type, val, result, pos);
// Unroll leading dimensions.
- Type lowType = VectorType::Builder(type).dropDim(0);
- VectorType vType = cast<VectorType>(lowType);
- Type insType = VectorType::Builder(vType).dropDim(0);
+ VectorType lowType = VectorType::Builder(type).dropDim(0);
+ Type insType = lowType.getRank() > 1 ? VectorType::Builder(lowType).dropDim(0)
+ : lowType.getElementType();
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
- Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, d);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, lowType, result, d);
Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, d);
- Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
+ Value sto = reshapeStore(loc, ins, ext, lowType, index - 1, pos, rewriter);
result = rewriter.create<vector::InsertOp>(loc, type, sto, result, d);
}
return result;
More information about the Mlir-commits
mailing list