[Mlir-commits] [mlir] a7a5641 - [mlir][vector] Fix bug in `TransferWriteNonPermutationLowering`
Matthias Springer
llvmlistbot at llvm.org
Mon Jul 10 08:25:46 PDT 2023
Author: Matthias Springer
Date: 2023-07-10T17:21:03+02:00
New Revision: a7a5641bdcfa92e95771ccfcc0a14d42611ac2f8
URL: https://github.com/llvm/llvm-project/commit/a7a5641bdcfa92e95771ccfcc0a14d42611ac2f8
DIFF: https://github.com/llvm/llvm-project/commit/a7a5641bdcfa92e95771ccfcc0a14d42611ac2f8.diff
LOG: [mlir][vector] Fix bug in `TransferWriteNonPermutationLowering`
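This pattern expands the rank of the vector by adding leading unit dimensions, but it did not expand the rank of the mask accordingly. The mask now receives matching unit dimensions at the end of its shape.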
Differential Revision: https://reviews.llvm.org/D154849
Added:
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Modified:
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4f68526ac401ea..af591730ad963e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -46,6 +46,21 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}
+/// Extend the rank of mask `vec` by `addedRank` inner (trailing) unit dims:
+/// broadcast with leading unit dims, then transpose them to the innermost end.
+static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
+ int64_t addedRank) {
+ Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
+ SmallVector<int64_t> permutation;
+ for (int64_t i = addedRank, // Original mask dims go first...
+ e = broadcasted.getType().cast<VectorType>().getRank();
+ i < e; ++i)
+ permutation.push_back(i);
+ for (int64_t i = 0; i < addedRank; ++i) // ...then the added unit dims.
+ permutation.push_back(i);
+ return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
+}
+
//===----------------------------------------------------------------------===//
// populateVectorTransferPermutationMapLoweringPatterns
//===----------------------------------------------------------------------===//
@@ -246,9 +261,14 @@ struct TransferWriteNonPermutationLowering
missingInnerDim.push_back(i);
exprs.push_back(rewriter.getAffineDimExpr(i));
}
- // Add unit dims at the beginning of the shape.
+ // Vector: add unit dims at the beginning of the shape.
Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
missingInnerDim.size());
+ // Mask: add unit dims at the end of the shape.
+ Value newMask;
+ if (op.getMask())
+ newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
+ missingInnerDim.size());
exprs.append(map.getResults().begin(), map.getResults().end());
AffineMap newMap =
AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
@@ -263,7 +283,7 @@ struct TransferWriteNonPermutationLowering
}
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
- op.getMask(), newInBoundsAttr);
+ newMask, newInBoundsAttr);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
new file mode 100644
index 00000000000000..6ea53aa3f41b07
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+// Regression test: the transfer_write mask must be rank-extended with the vector.
+// CHECK-LABEL: func @lower_permutation_with_mask(
+// CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
+// CHECK: %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1>
+// CHECK: %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1>
+// CHECK: %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
+// CHECK: vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
+func.func @lower_permutation_with_mask(%A : memref<?x?xf32>, %base1 : index,
+ %base2 : index) {
+ %fn1 = arith.constant -2.0 : f32
+ %vf0 = vector.splat %fn1 : vector<7xf32>
+ %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
+ vector.transfer_write %vf0, %A[%base1, %base2], %mask
+ {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [false]}
+ : vector<7xf32>, memref<?x?xf32>
+ return
+}
+// Apply the transfer permutation lowering patterns to every func.func.
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ } : !transform.any_op
+}
More information about the Mlir-commits mailing list