[Mlir-commits] [mlir] 7db18ef - [mlir][SPIR-V] Fold IAddCarry/[SU]MulExtended constants into spirv.Constant struct (#198633)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 20 02:08:44 PDT 2026
Author: Arseniy Obolenskiy
Date: 2026-05-20T11:08:37+02:00
New Revision: 7db18ef27472ff7d26f7029f7bf87511c63fda88
URL: https://github.com/llvm/llvm-project/commit/7db18ef27472ff7d26f7029f7bf87511c63fda88
DIFF: https://github.com/llvm/llvm-project/commit/7db18ef27472ff7d26f7029f7bf87511c63fda88.diff
LOG: [mlir][SPIR-V] Fold IAddCarry/[SU]MulExtended constants into spirv.Constant struct (#198633)
Emit a single spirv.Constant of struct type instead of the previous workaround
of a spirv.Undef followed by two spirv.CompositeInsert ops.
Added:
Modified:
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 2bae0d23c9f00..e4b48ed680d52 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -125,17 +125,13 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
// spirv.IAddCarry
//===----------------------------------------------------------------------===//
-// We are required to use CompositeConstructOp to create a constant struct as
-// they are not yet implemented as constant, hence we can not do so in a fold.
struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
using Base::Base;
LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
Value lhs = op.getOperand1();
Value rhs = op.getOperand2();
- Type constituentType = lhs.getType();
// iaddcarry (x, 0) = <0, x>
if (matchPattern(rhs, m_Zero())) {
@@ -177,20 +173,8 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
if (!carrys)
return failure();
- Value addsVal =
- spirv::ConstantOp::create(rewriter, loc, constituentType, adds);
-
- Value carrysVal =
- spirv::ConstantOp::create(rewriter, loc, constituentType, carrys);
-
- // Create empty struct
- Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
- // Fill in adds at id 0
- Value intermediate =
- spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0);
- // Fill in carrys at id 1
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
- intermediate, 1);
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+ op, op.getType(), rewriter.getArrayAttr({adds, carrys}));
return success();
}
};
@@ -204,8 +188,6 @@ void spirv::IAddCarryOp::getCanonicalizationPatterns(
// spirv.[S|U]MulExtended
//===----------------------------------------------------------------------===//
-// We are required to use CompositeConstructOp to create a constant struct as
-// they are not yet implemented as constant, hence we can not do so in a fold.
template <typename MulOp, bool IsSigned>
struct MulExtendedFold final : OpRewritePattern<MulOp> {
using OpRewritePattern<MulOp>::OpRewritePattern;
@@ -258,20 +240,8 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
if (!highBits)
return failure();
- Value lowBitsVal =
- spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits);
-
- Value highBitsVal =
- spirv::ConstantOp::create(rewriter, loc, constituentType, highBits);
-
- // Create empty struct
- Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
- // Fill in lowBits at id 0
- Value intermediate =
- spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0);
- // Fill in highBits at id 1
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
- intermediate, 1);
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+ op, op.getType(), rewriter.getArrayAttr({lowBits, highBits}));
return success();
}
};
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 722c27586aa61..235ca15c08d7b 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -356,20 +356,12 @@ func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.s
%cn5 = spirv.Constant -5 : i32
%cn8 = spirv.Constant -8 : i32
- // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
- // CHECK-DAG: %[[CN3:.*]] = spirv.Constant -3
- // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
- // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN3]], %[[UNDEF1]][0 : i32]
- // CHECK-DAG: %[[CC_CN3_C0:.*]] = spirv.CompositeInsert %[[C0]], %[[INTER1]][1 : i32]
- // CHECK-DAG: %[[C1:.*]] = spirv.Constant 1
- // CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
- // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
- // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[CN13]], %[[UNDEF2]][0 : i32]
- // CHECK-DAG: %[[CC_CN13_C1:.*]] = spirv.CompositeInsert %[[C1]], %[[INTER2]][1 : i32]
+ // CHECK-DAG: %[[CST_CN3_C0:.*]] = spirv.Constant [-3 : i32, 0 : i32] : !spirv.struct<(i32, i32)>
+ // CHECK-DAG: %[[CST_CN13_C1:.*]] = spirv.Constant [-13 : i32, 1 : i32] : !spirv.struct<(i32, i32)>
%0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
%1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
- // CHECK: return %[[CC_CN3_C0]], %[[CC_CN13_C1]]
+ // CHECK: return %[[CST_CN3_C0]], %[[CST_CN13_C1]]
return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
}
@@ -378,14 +370,10 @@ func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector
%v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
%v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
- // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[-3, -11, 0]>
- // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[0, 1, 1]>
- // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
- // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]][0 : i32]
- // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]][1 : i32]
+ // CHECK: %[[CST:.*]] = spirv.Constant [dense<[-3, -11, 0]> : vector<3xi32>, dense<[0, 1, 1]> : vector<3xi32>] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
%0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
- // CHECK: return %[[CC_CV1_CV2]]
+ // CHECK: return %[[CST]]
return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
}
@@ -476,20 +464,12 @@ func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spir
%cn5 = spirv.Constant -5 : i32
%cn8 = spirv.Constant -8 : i32
- // CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
- // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1
- // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
- // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN40]], %[[UNDEF1]][0 : i32]
- // CHECK-DAG: %[[CC_CN40_CN1:.*]] = spirv.CompositeInsert %[[CN1]], %[[INTER1]]
- // CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
- // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
- // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
- // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[C40]], %[[UNDEF2]][0 : i32]
- // CHECK-DAG: %[[CC_C40_C0:.*]] = spirv.CompositeInsert %[[C0]], %[[INTER2]][1 : i32]
+ // CHECK-DAG: %[[CST_CN40_CN1:.*]] = spirv.Constant [-40 : i32, -1 : i32] : !spirv.struct<(i32, i32)>
+ // CHECK-DAG: %[[CST_C40_C0:.*]] = spirv.Constant [40 : i32, 0 : i32] : !spirv.struct<(i32, i32)>
%0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
%1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
- // CHECK: return %[[CC_CN40_CN1]], %[[CC_C40_C0]]
+ // CHECK: return %[[CST_CN40_CN1]], %[[CST_C40_C0]]
return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
}
@@ -498,14 +478,10 @@ func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vec
%v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
%v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
- // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
- // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, 0, -1]>
- // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
- // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]][0 : i32]
- // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]][1 : i32]
+ // CHECK: %[[CST:.*]] = spirv.Constant [dense<[2147483643, 40, -1]> : vector<3xi32>, dense<[2, 0, -1]> : vector<3xi32>] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
%0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
- // CHECK: return %[[CC_CV1_CV2]]
+ // CHECK: return %[[CST]]
return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
}
@@ -545,21 +521,12 @@ func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spir
%cn5 = spirv.Constant -5 : i32
%cn8 = spirv.Constant -8 : i32
-
- // CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
- // CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
- // CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
- // CHECK-DAG: %[[C4:.*]] = spirv.Constant 4
- // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
- // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN40]], %[[UNDEF1]][0 : i32]
- // CHECK-DAG: %[[CC_CN40_C4:.*]] = spirv.CompositeInsert %[[C4]], %[[INTER1]][1 : i32]
- // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
- // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[C40]], %[[UNDEF2]][0 : i32]
- // CHECK-DAG: %[[CC_C40_CN13:.*]] = spirv.CompositeInsert %[[CN13]], %[[INTER2]][1 : i32]
+ // CHECK-DAG: %[[CST_CN40_C4:.*]] = spirv.Constant [-40 : i32, 4 : i32] : !spirv.struct<(i32, i32)>
+ // CHECK-DAG: %[[CST_C40_CN13:.*]] = spirv.Constant [40 : i32, -13 : i32] : !spirv.struct<(i32, i32)>
%0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
%1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
- // CHECK: return %[[CC_CN40_C4]], %[[CC_C40_CN13]]
+ // CHECK: return %[[CST_CN40_C4]], %[[CST_C40_CN13]]
return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
}
@@ -568,14 +535,10 @@ func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vec
%v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
%v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
- // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
- // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, -13, 0]>
- // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
- // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]]
- // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]]
+ // CHECK: %[[CST:.*]] = spirv.Constant [dense<[2147483643, 40, -1]> : vector<3xi32>, dense<[2, -13, 0]> : vector<3xi32>] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
%0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
- // CHECK: return %[[CC_CV1_CV2]]
+ // CHECK: return %[[CST]]
return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
}
More information about the Mlir-commits
mailing list