[Mlir-commits] [mlir] b4bdcea - [mlir][arith] Define mului_extended op
Jakub Kuderski
llvmlistbot at llvm.org
Fri Dec 9 14:38:37 PST 2022
Author: Jakub Kuderski
Date: 2022-12-09T17:37:06-05:00
New Revision: b4bdcea2148f8885ac4a109e1cd3c9e67a46c161
URL: https://github.com/llvm/llvm-project/commit/b4bdcea2148f8885ac4a109e1cd3c9e67a46c161
DIFF: https://github.com/llvm/llvm-project/commit/b4bdcea2148f8885ac4a109e1cd3c9e67a46c161.diff
LOG: [mlir][arith] Define mului_extended op
Add conversion to the SPIR-V and LLVM dialects.
This was originally proposed in:
https://discourse.llvm.org/t/rfc-arith-add-extended-multiplication-ops/66869.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D139688
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
mlir/test/Dialect/Arith/canonicalize.mlir
mlir/test/Dialect/Arith/ops.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 0c54acac59a36..f02a2f89a987d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -291,6 +291,49 @@ def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// MulUIExtendedOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
+ AllTypesMatch<["lhs", "rhs", "low", "high"]>]> {
+ let summary = [{
+ extended unsigned integer multiplication operation
+ }];
+
+ let description = [{
+ Performs (2*N)-bit multiplication on zero-extended operands. Returns two
+ N-bit results: the low and the high halves of the product. The low half has
+ the same value as the result of regular multiplication `arith.muli` with
+ the same operands.
+
+ Example:
+
+ ```mlir
+ // Scalar multiplication.
+ %low, %high = arith.mului_extended %a, %b : i32
+
+ // Vector element-wise multiplication.
+ %c:2 = arith.mului_extended %d, %e : vector<4xi32>
+
+ // Tensor element-wise multiplication.
+ %x:2 = arith.mului_extended %y, %z : tensor<4x?xi8>
+ ```
+ }];
+
+ let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
+ let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
+
+ let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
+
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
+
+ let extraClassDeclaration = [{
+ ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// DivUIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 0289bea88b504..9483727f68508 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -142,6 +142,15 @@ struct AddUIExtendedOpLowering
ConversionPatternRewriter &rewriter) const override;
};
+struct MulUIExtendedOpLowering
+ : public ConvertOpToLLVMPattern<arith::MulUIExtendedOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::MulUIExtendedOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -261,6 +270,67 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
"ND vector types are not supported yet");
}
+//===----------------------------------------------------------------------===//
+// MulUIExtendedOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult MulUIExtendedOpLowering::matchAndRewrite(
+ arith::MulUIExtendedOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Type resultType = adaptor.getLhs().getType();
+
+ if (!LLVM::isCompatibleType(resultType))
+ return failure();
+
+ Location loc = op.getLoc();
+
+ // Handle the scalar and 1D vector cases. Because LLVM does not have a
+ // matching extended multiplication intrinsic, perform regular multiplication
+ // on operands zero-extended to i(2*N) bits, and truncate the results back to
+ // iN types.
+ if (!resultType.isa<LLVM::LLVMArrayType>()) {
+ Type wideType;
+ // Shift amount necessary to extract the high bits from widened result.
+ Attribute shiftValAttr;
+
+ if (auto intTy = resultType.dyn_cast<IntegerType>()) {
+ unsigned resultBitwidth = intTy.getWidth();
+ wideType = rewriter.getIntegerType(resultBitwidth * 2);
+ shiftValAttr = rewriter.getIntegerAttr(wideType, resultBitwidth);
+ } else {
+ auto vecTy = resultType.cast<VectorType>();
+ unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
+ wideType = VectorType::get(vecTy.getShape(),
+ rewriter.getIntegerType(resultBitwidth * 2));
+ shiftValAttr = SplatElementsAttr::get(
+ wideType, APInt(resultBitwidth * 2, resultBitwidth));
+ }
+ assert(LLVM::isCompatibleType(wideType) &&
+ "LLVM dialect should support all signless integer types");
+
+ Value lhsExt =
+ rewriter.create<LLVM::ZExtOp>(loc, wideType, adaptor.getLhs());
+ Value rhsExt =
+ rewriter.create<LLVM::ZExtOp>(loc, wideType, adaptor.getRhs());
+ Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
+
+ // Split the 2*N-bit wide result into two N-bit values.
+ Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt);
+ Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr);
+ Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal);
+ Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt);
+
+ rewriter.replaceOp(op, {low, high});
+ return success();
+ }
+
+ if (!resultType.isa<VectorType>())
+ return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+ return rewriter.notifyMatchFailure(op,
+ "ND vector types are not supported yet");
+}
+
//===----------------------------------------------------------------------===//
// CmpIOpLowering
//===----------------------------------------------------------------------===//
@@ -397,6 +467,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
MinUIOpLowering,
MulFOpLowering,
MulIOpLowering,
+ MulUIExtendedOpLowering,
NegFOpLowering,
OrIOpLowering,
RemFOpLowering,
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a127dd8f4e8a6..03e877f56578d 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -223,6 +223,16 @@ class AddUIExtendedOpPattern final
ConversionPatternRewriter &rewriter) const override;
};
+/// Converts arith.mului_extended to spirv.UMulExtended.
+class MulUIExtendedOpPattern final
+ : public OpConversionPattern<arith::MulUIExtendedOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(arith::MulUIExtendedOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts arith.select to spirv.Select.
class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
public:
@@ -944,6 +954,26 @@ LogicalResult AddUIExtendedOpPattern::matchAndRewrite(
return success();
}
+//===----------------------------------------------------------------------===//
+// MulUIExtendedOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult MulUIExtendedOpPattern::matchAndRewrite(
+ arith::MulUIExtendedOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op->getLoc();
+ Value result = rewriter.create<spirv::UMulExtendedOp>(loc, adaptor.getLhs(),
+ adaptor.getRhs());
+
+ Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
+ llvm::makeArrayRef(0));
+ Value high = rewriter.create<spirv::CompositeExtractOp>(
+ loc, result, llvm::makeArrayRef(1));
+
+ rewriter.replaceOp(op, {low, high});
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// SelectOpPattern
//===----------------------------------------------------------------------===//
@@ -1040,7 +1070,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
CmpFOpNanNonePattern, CmpFOpPattern,
- AddUIExtendedOpPattern, SelectOpPattern,
+ AddUIExtendedOpPattern, MulUIExtendedOpPattern, SelectOpPattern,
MinMaxFOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
MinMaxFOpPattern<arith::MinFOp, spirv::GLFMinOp>,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index a12c1fe5c6326..7637369c4e552 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -111,6 +111,17 @@ def SubISubILHSRHSLHS :
Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
(Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
+//===----------------------------------------------------------------------===//
+// MulUIExtendedOp
+//===----------------------------------------------------------------------===//
+
+// mului_extended(x, y) -> [muli(x, y), x], when the `high` result is unused.
+// Since the `high` result is not used, any replacement value will do.
+def MulUIExtendedToMulI :
+ Pattern<(Arith_MulUIExtendedOp:$res $x, $y),
+ [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+ [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
+
//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 29bce4def8745..fde37d6bd69d4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -237,7 +237,7 @@ static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
LogicalResult
arith::AddUIExtendedOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- auto overflowTy = getOverflow().getType();
+ Type overflowTy = getOverflow().getType();
// addui_extended(x, 0) -> x, false
if (matchPattern(getRhs(), m_Zero())) {
auto overflowZero = APInt::getZero(1);
@@ -350,6 +350,61 @@ OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
operands, [](const APInt &a, const APInt &b) { return a * b; });
}
+//===----------------------------------------------------------------------===//
+// MulUIExtendedOp
+//===----------------------------------------------------------------------===//
+
+Optional<SmallVector<int64_t, 4>> arith::MulUIExtendedOp::getShapeForUnroll() {
+ if (auto vt = getType(0).dyn_cast<VectorType>())
+ return llvm::to_vector<4>(vt.getShape());
+ return std::nullopt;
+}
+
+LogicalResult
+arith::MulUIExtendedOp::fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ // mului_extended(x, 0) -> 0, 0
+ if (matchPattern(getRhs(), m_Zero())) {
+ Attribute zero = operands[1];
+ results.push_back(zero);
+ results.push_back(zero);
+ return success();
+ }
+
+ // mului_extended(x, 1) -> x, 0
+ if (matchPattern(getRhs(), m_One())) {
+ Builder builder(getContext());
+ Attribute zero = builder.getZeroAttr(getLhs().getType());
+ results.push_back(getLhs());
+ results.push_back(zero);
+ return success();
+ }
+
+ // mului_extended(cst_a, cst_b) -> cst_low, cst_high
+ if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
+ operands, [](const APInt &a, const APInt &b) { return a * b; })) {
+ // Invoke the constant fold helper again to calculate the 'high' result.
+ Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
+ operands, [](const APInt &a, const APInt &b) {
+ unsigned bitWidth = a.getBitWidth();
+ APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+ return fullProduct.extractBits(bitWidth, bitWidth);
+ });
+ assert(highAttr && "Unexpected constant-folding failure");
+
+ results.push_back(lowAttr);
+ results.push_back(highAttr);
+ return success();
+ }
+
+ return failure();
+}
+
+void arith::MulUIExtendedOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<MulUIExtendedToMulI>(context);
+}
+
//===----------------------------------------------------------------------===//
// DivUIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index cf207c283b0d4..8abf8130744a5 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -379,6 +379,38 @@ func.func @addui_extended_vector1d(%arg0: vector<3xi16>, %arg1: vector<3xi16>) -
// -----
+// CHECK-LABEL: @mului_extended_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32)
+func.func @mului_extended_scalar(%arg0: i32, %arg1: i32) -> (i32, i32) {
+ // CHECK-NEXT: [[LHS:%.+]] = llvm.zext [[ARG0]] : i32 to i64
+ // CHECK-NEXT: [[RHS:%.+]] = llvm.zext [[ARG1]] : i32 to i64
+ // CHECK-NEXT: [[MUL:%.+]] = llvm.mul [[LHS]], [[RHS]] : i64
+ // CHECK-NEXT: [[LOW:%.+]] = llvm.trunc [[MUL]] : i64 to i32
+ // CHECK-NEXT: [[C32:%.+]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK-NEXT: [[SHL:%.+]] = llvm.lshr [[MUL]], [[C32]] : i64
+ // CHECK-NEXT: [[HIGH:%.+]] = llvm.trunc [[SHL]] : i64 to i32
+ %low, %high = arith.mului_extended %arg0, %arg1 : i32
+ // CHECK-NEXT: return [[LOW]], [[HIGH]] : i32, i32
+ return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @mului_extended_vector1d
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi64>, [[ARG1:%.+]]: vector<3xi64>) -> (vector<3xi64>, vector<3xi64>)
+func.func @mului_extended_vector1d(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> (vector<3xi64>, vector<3xi64>) {
+ // CHECK-NEXT: [[LHS:%.+]] = llvm.zext [[ARG0]] : vector<3xi64> to vector<3xi128>
+ // CHECK-NEXT: [[RHS:%.+]] = llvm.zext [[ARG1]] : vector<3xi64> to vector<3xi128>
+ // CHECK-NEXT: [[MUL:%.+]] = llvm.mul [[LHS]], [[RHS]] : vector<3xi128>
+ // CHECK-NEXT: [[LOW:%.+]] = llvm.trunc [[MUL]] : vector<3xi128> to vector<3xi64>
+ // CHECK-NEXT: [[C64:%.+]] = llvm.mlir.constant(dense<64> : vector<3xi128>) : vector<3xi128>
+ // CHECK-NEXT: [[SHL:%.+]] = llvm.lshr [[MUL]], [[C64]] : vector<3xi128>
+ // CHECK-NEXT: [[HIGH:%.+]] = llvm.trunc [[SHL]] : vector<3xi128> to vector<3xi64>
+ %low, %high = arith.mului_extended %arg0, %arg1 : vector<3xi64>
+ // CHECK-NEXT: return [[LOW]], [[HIGH]] : vector<3xi64>, vector<3xi64>
+ return %low, %high : vector<3xi64>, vector<3xi64>
+}
+
+// -----
+
// CHECK-LABEL: func @cmpf_2dvector(
func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 938bafa357cf3..c2f642e3fb264 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -99,6 +99,29 @@ func.func @int32_vector_addui_extended(%lhs: vector<4xi32>, %rhs: vector<4xi32>)
return %sum, %overflow : vector<4xi32>, vector<4xi1>
}
+// Check extended unsigned integer multiplication conversions.
+// CHECK-LABEL: @int32_scalar_mului_extended
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_mului_extended(%lhs: i32, %rhs: i32) -> (i32, i32) {
+ // CHECK-NEXT: %[[MUL:.+]] = spirv.UMulExtended %[[LHS]], %[[RHS]] : !spirv.struct<(i32, i32)>
+ // CHECK-DAG: %[[LOW:.+]] = spirv.CompositeExtract %[[MUL]][0 : i32] : !spirv.struct<(i32, i32)>
+ // CHECK-DAG: %[[HIGH:.+]] = spirv.CompositeExtract %[[MUL]][1 : i32] : !spirv.struct<(i32, i32)>
+ // CHECK-NEXT: return %[[LOW]], %[[HIGH]] : i32, i32
+ %low, %high = arith.mului_extended %lhs, %rhs: i32
+ return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @int32_vector_mului_extended
+// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>)
+func.func @int32_vector_mului_extended(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+ // CHECK-NEXT: %[[MUL:.+]] = spirv.UMulExtended %[[LHS]], %[[RHS]] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+ // CHECK-DAG: %[[LOW:.+]] = spirv.CompositeExtract %[[MUL]][0 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+ // CHECK-DAG: %[[HIGH:.+]] = spirv.CompositeExtract %[[MUL]][1 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+ // CHECK-NEXT: return %[[LOW]], %[[HIGH]] : vector<4xi32>, vector<4xi32>
+ %low, %high = arith.mului_extended %lhs, %rhs: vector<4xi32>
+ return %low, %high : vector<4xi32>, vector<4xi32>
+}
+
// Check float unary operation conversions.
// CHECK-LABEL: @float32_unary_scalar
func.func @float32_unary_scalar(%arg0: f32) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 6a1c0feb9025d..1280a2adece89 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -734,6 +734,104 @@ func.func @adduiExtendedConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>)
return %sum, %overflow : vector<4xi32>, vector<4xi1>
}
+// CHECK-LABEL: @muluiExtendedZeroRhs
+// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: return %[[zero]], %[[zero]]
+func.func @muluiExtendedZeroRhs(%arg0: i32) -> (i32, i32) {
+ %zero = arith.constant 0 : i32
+ %low, %high = arith.mului_extended %arg0, %zero: i32
+ return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @muluiExtendedZeroRhsSplat
+// CHECK-NEXT: %[[zero:.+]] = arith.constant dense<0> : vector<3xi32>
+// CHECK-NEXT: return %[[zero]], %[[zero]]
+func.func @muluiExtendedZeroRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
+ %zero = arith.constant dense<0> : vector<3xi32>
+ %low, %high = arith.mului_extended %arg0, %zero: vector<3xi32>
+ return %low, %high : vector<3xi32>, vector<3xi32>
+}
+
+// CHECK-LABEL: @muluiExtendedZeroLhs
+// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: return %[[zero]], %[[zero]]
+func.func @muluiExtendedZeroLhs(%arg0: i32) -> (i32, i32) {
+ %zero = arith.constant 0 : i32
+ %low, %high = arith.mului_extended %zero, %arg0: i32
+ return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @muluiExtendedOneRhs
+// CHECK-SAME: (%[[ARG:.+]]: i32) -> (i32, i32)
+// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: return %[[ARG]], %[[zero]]
+func.func @muluiExtendedOneRhs(%arg0: i32) -> (i32, i32) {
+ %zero = arith.constant 1 : i32
+ %low, %high = arith.mului_extended %arg0, %zero: i32
+ return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @muluiExtendedOneRhsSplat
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>)
+// CHECK-NEXT: %[[zero:.+]] = arith.constant dense<0> : vector<3xi32>
+// CHECK-NEXT: return %[[ARG]], %[[zero]]
+func.func @muluiExtendedOneRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
+ %zero = arith.constant dense<1> : vector<3xi32>
+ %low, %high = arith.mului_extended %arg0, %zero: vector<3xi32>
+ return %low, %high : vector<3xi32>, vector<3xi32>
+}
+
+// CHECK-LABEL: @muluiExtendedOneLhs
+// CHECK-SAME: (%[[ARG:.+]]: i32) -> (i32, i32)
+// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: return %[[ARG]], %[[zero]]
+func.func @muluiExtendedOneLhs(%arg0: i32) -> (i32, i32) {
+ %zero = arith.constant 1 : i32
+ %low, %high = arith.mului_extended %zero, %arg0: i32
+ return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @muluiExtendedUnusedHigh
+// CHECK-SAME: (%[[ARG:.+]]: i32) -> i32
+// CHECK-NEXT: %[[RES:.+]] = arith.muli %[[ARG]], %[[ARG]] : i32
+// CHECK-NEXT: return %[[RES]]
+func.func @muluiExtendedUnusedHigh(%arg0: i32) -> i32 {
+ %low, %high = arith.mului_extended %arg0, %arg0: i32
+ return %low : i32
+}
+
+// This shouldn't be folded.
+// CHECK-LABEL: @muluiExtendedUnusedLow
+// CHECK-SAME: (%[[ARG:.+]]: i32) -> i32
+// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[ARG]], %[[ARG]] : i32
+// CHECK-NEXT: return %[[HIGH]]
+func.func @muluiExtendedUnusedLow(%arg0: i32) -> i32 {
+ %low, %high = arith.mului_extended %arg0, %arg0: i32
+ return %high : i32
+}
+
+// CHECK-LABEL: @muluiExtendedScalarConstants
+// CHECK-DAG: %[[c157:.+]] = arith.constant -99 : i8
+// CHECK-DAG: %[[c29:.+]] = arith.constant 29 : i8
+// CHECK-NEXT: return %[[c157]], %[[c29]]
+func.func @muluiExtendedScalarConstants() -> (i8, i8) {
+ %c57 = arith.constant 57 : i8
+ %c133 = arith.constant 133 : i8
+ %low, %high = arith.mului_extended %c57, %c133: i8 // = 7581
+ return %low, %high : i8, i8
+}
+
+// CHECK-LABEL: @muluiExtendedVectorConstants
+// CHECK-DAG: %[[cstLo:.+]] = arith.constant dense<[65, 79, 1]> : vector<3xi8>
+// CHECK-DAG: %[[cstHi:.+]] = arith.constant dense<[0, 14, -2]> : vector<3xi8>
+// CHECK-NEXT: return %[[cstLo]], %[[cstHi]]
+func.func @muluiExtendedVectorConstants() -> (vector<3xi8>, vector<3xi8>) {
+ %cstA = arith.constant dense<[5, 37, 255]> : vector<3xi8>
+ %cstB = arith.constant dense<[13, 99, 255]> : vector<3xi8>
+ %low, %high = arith.mului_extended %cstA, %cstB: vector<3xi8>
+ return %low, %high : vector<3xi8>, vector<3xi8>
+}
+
// CHECK-LABEL: @notCmpEQ
// CHECK: %[[cres:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 99a777d3d5f79..a9c0b557fdf01 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -97,6 +97,30 @@ func.func @test_muli_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]
return %0 : vector<[8]xi64>
}
+// CHECK-LABEL: test_mului_extended
+func.func @test_mului_extended(%arg0 : i32, %arg1 : i32) -> i32 {
+ %low, %high = arith.mului_extended %arg0, %arg1 : i32
+ return %high : i32
+}
+
+// CHECK-LABEL: test_mului_extended_tensor
+func.func @test_mului_extended_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
+ %low, %high = arith.mului_extended %arg0, %arg1 : tensor<8x8xi64>
+ return %high : tensor<8x8xi64>
+}
+
+// CHECK-LABEL: test_mului_extended_vector
+func.func @test_mului_extended_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
+ %0:2 = arith.mului_extended %arg0, %arg1 : vector<8xi64>
+ return %0#0 : vector<8xi64>
+}
+
+// CHECK-LABEL: test_mului_extended_scalable_vector
+func.func @test_mului_extended_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+ %0:2 = arith.mului_extended %arg0, %arg1 : vector<[8]xi64>
+ return %0#1 : vector<[8]xi64>
+}
+
// CHECK-LABEL: test_divui
func.func @test_divui(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.divui %arg0, %arg1 : i64
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index cdb1536ce65bb..a13429ea8bfb3 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1044,7 +1044,8 @@ void PatternEmitter::emitRewriteLogic() {
}
if (offsets.front() > 0) {
- const char error[] = "no enough values generated to replace the matched op";
+ const char error[] =
+ "not enough values generated to replace the matched op";
PrintFatalError(loc, error);
}
More information about the Mlir-commits
mailing list