Skip to content

Commit

Permalink
feat: make virtual (#46)
Browse files Browse the repository at this point in the history
* feat: make virtual

* test: fix before amount

* test: auctioneer
  • Loading branch information
Schlagonia authored Aug 21, 2024
1 parent 6ce8d29 commit d732e91
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 16 deletions.
24 changes: 15 additions & 9 deletions src/swappers/TradeFactorySwapper.sol
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ abstract contract TradeFactorySwapper {
* proper functions to avoid issues.
* @return The current trade factory in use if any.
*/
function tradeFactory() public view returns (address) {
function tradeFactory() public view virtual returns (address) {
return _tradeFactory;
}

Expand All @@ -38,14 +38,17 @@ abstract contract TradeFactorySwapper {
* proper functions to avoid issues.
* @return The current array of tokens being sold if any.
*/
function rewardTokens() public view returns (address[] memory) {
function rewardTokens() public view virtual returns (address[] memory) {
return _rewardTokens;
}

/**
* @dev Add an array of tokens to sell to its corresponding `_to_.
*/
function _addTokens(address[] memory _from, address[] memory _to) internal {
function _addTokens(
address[] memory _from,
address[] memory _to
) internal virtual {
for (uint256 i; i < _from.length; ++i) {
_addToken(_from[i], _to[i]);
}
Expand All @@ -54,7 +57,7 @@ abstract contract TradeFactorySwapper {
/**
* @dev Add the `_tokenFrom` to be sold to `_tokenTo` through the Trade Factory
*/
function _addToken(address _tokenFrom, address _tokenTo) internal {
function _addToken(address _tokenFrom, address _tokenTo) internal virtual {
address _tf = tradeFactory();
if (_tf != address(0)) {
ERC20(_tokenFrom).forceApprove(_tf, type(uint256).max);
Expand All @@ -68,7 +71,10 @@ abstract contract TradeFactorySwapper {
* @dev Remove a specific `_tokenFrom` that was previously added to not be
* sold through the Trade Factory any more.
*/
function _removeToken(address _tokenFrom, address _tokenTo) internal {
function _removeToken(
address _tokenFrom,
address _tokenTo
) internal virtual {
address _tf = tradeFactory();
address[] memory _rewardTokensLocal = rewardTokens();
for (uint256 i; i < _rewardTokensLocal.length; ++i) {
Expand All @@ -95,7 +101,7 @@ abstract contract TradeFactorySwapper {
/**
* @dev Removes all reward tokens and delete the Trade Factory.
*/
function _deleteRewardTokens() internal {
function _deleteRewardTokens() internal virtual {
_removeTradeFactoryPermissions();
delete _rewardTokens;
}
Expand All @@ -109,7 +115,7 @@ abstract contract TradeFactorySwapper {
function _setTradeFactory(
address tradeFactory_,
address _tokenTo
) internal {
) internal virtual {
address _tf = tradeFactory();

// Remove any old Trade Factory
Expand All @@ -136,7 +142,7 @@ abstract contract TradeFactorySwapper {
/**
* @dev Remove any active approvals and set the trade factory to address(0).
*/
function _removeTradeFactoryPermissions() internal {
function _removeTradeFactoryPermissions() internal virtual {
address _tf = tradeFactory();
address[] memory rewardTokensLocal = rewardTokens();
for (uint256 i; i < rewardTokensLocal.length; ++i) {
Expand All @@ -149,7 +155,7 @@ abstract contract TradeFactorySwapper {
/**
* @notice Used for TradeFactory to claim rewards.
*/
function claimRewards() external {
function claimRewards() external virtual {
require(msg.sender == _tradeFactory, "!authorized");
_claimRewards();
}
Expand Down
9 changes: 6 additions & 3 deletions src/test/Auction.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ contract AuctionTest is Setup, ITaker {
skip(auction.auctionLength() / 2);

uint256 needed = auction.getAmountNeeded(id, _amount);
uint256 beforeAsset = ERC20(asset).balanceOf(address(this));

airdrop(ERC20(asset), address(this), needed);

Expand All @@ -311,7 +312,7 @@ contract AuctionTest is Setup, ITaker {
(, , , _available) = auction.auctionInfo(id);
assertEq(_available, 0);

assertEq(ERC20(asset).balanceOf(address(this)), 0);
assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset);
assertEq(ERC20(from).balanceOf(address(this)), before + _amount);
assertEq(ERC20(from).balanceOf(address(auction)), 0);
assertEq(ERC20(asset).balanceOf(address(mockStrategy)), needed);
Expand Down Expand Up @@ -351,6 +352,7 @@ contract AuctionTest is Setup, ITaker {
uint256 toTake = (_amount * _percent) / MAX_BPS;
uint256 left = _amount - toTake;
uint256 needed = auction.getAmountNeeded(id, toTake);
uint256 beforeAsset = ERC20(asset).balanceOf(address(this));

airdrop(ERC20(asset), address(this), needed);

Expand All @@ -366,7 +368,7 @@ contract AuctionTest is Setup, ITaker {

(, , , _available) = auction.auctionInfo(id);
assertEq(_available, left);
assertEq(ERC20(asset).balanceOf(address(this)), 0);
assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset);
assertEq(ERC20(from).balanceOf(address(this)), before + toTake);
assertEq(ERC20(from).balanceOf(address(auction)), left);
assertEq(ERC20(asset).balanceOf(address(mockStrategy)), needed);
Expand Down Expand Up @@ -405,6 +407,7 @@ contract AuctionTest is Setup, ITaker {
uint256 toTake = _amount / 2;
uint256 left = _amount - toTake;
uint256 needed = auction.getAmountNeeded(id, toTake);
uint256 beforeAsset = ERC20(asset).balanceOf(address(this));

airdrop(ERC20(asset), address(this), needed);

Expand All @@ -424,7 +427,7 @@ contract AuctionTest is Setup, ITaker {

(, , , _available) = auction.auctionInfo(id);
assertEq(_available, left);
assertEq(ERC20(asset).balanceOf(address(this)), 0);
assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset);
assertEq(ERC20(from).balanceOf(address(this)), before + toTake);
assertEq(ERC20(from).balanceOf(address(auction)), left);
assertEq(ERC20(asset).balanceOf(address(mockStrategy)), needed);
Expand Down
6 changes: 4 additions & 2 deletions src/test/AuctionSwapper.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ contract AuctionSwapperTest is Setup {
uint256 toTake = (_amount * _percent) / MAX_BPS;
uint256 left = _amount - toTake;
uint256 needed = auction.getAmountNeeded(id, toTake);
uint256 beforeAsset = ERC20(asset).balanceOf(address(this));

airdrop(ERC20(asset), address(this), needed);

Expand All @@ -324,7 +325,7 @@ contract AuctionSwapperTest is Setup {

(, , , _available) = auction.auctionInfo(id);
assertEq(_available, left);
assertEq(ERC20(asset).balanceOf(address(this)), 0);
assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset);
assertEq(ERC20(from).balanceOf(address(this)), before + toTake);
assertEq(ERC20(from).balanceOf(address(auction)), left);
assertEq(ERC20(asset).balanceOf(address(swapper)), needed);
Expand Down Expand Up @@ -470,6 +471,7 @@ contract AuctionSwapperTest is Setup {
uint256 toTake = (kickable * _percent) / MAX_BPS;
uint256 left = kickable - toTake;
uint256 needed = auction.getAmountNeeded(id, toTake);
uint256 beforeAsset = ERC20(asset).balanceOf(address(this));

airdrop(ERC20(asset), address(this), needed);

Expand All @@ -487,7 +489,7 @@ contract AuctionSwapperTest is Setup {

(, , , _available) = auction.auctionInfo(id);
assertEq(_available, left);
assertEq(ERC20(asset).balanceOf(address(this)), 0);
assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset);
assertEq(ERC20(from).balanceOf(address(this)), before + toTake);
assertEq(ERC20(from).balanceOf(address(auction)), left);
assertEq(ERC20(asset).balanceOf(address(swapper)), needed);
Expand Down
6 changes: 4 additions & 2 deletions src/test/BaseAuctioneer.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ contract BaseAuctioneerTest is Setup {
uint256 toTake = (_amount * _percent) / MAX_BPS;
uint256 left = _amount - toTake;
uint256 needed = auctioneer.getAmountNeeded(id, toTake);
uint256 beforeAsset = ERC20(asset).balanceOf(address(this));

airdrop(ERC20(asset), address(this), needed);

Expand All @@ -291,7 +292,7 @@ contract BaseAuctioneerTest is Setup {

(, , , _available) = auctioneer.auctionInfo(id);
assertEq(_available, left);
assertEq(ERC20(asset).balanceOf(address(this)), 0);
assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset);
assertEq(ERC20(from).balanceOf(address(this)), before + toTake);
assertEq(ERC20(from).balanceOf(address(auctioneer)), left);
assertEq(ERC20(asset).balanceOf(address(auctioneer)), needed);
Expand Down Expand Up @@ -430,6 +431,7 @@ contract BaseAuctioneerTest is Setup {
uint256 toTake = (kickable * _percent) / MAX_BPS;
uint256 left = _amount - toTake;
uint256 needed = auctioneer.getAmountNeeded(id, toTake);
uint256 beforeAsset = ERC20(asset).balanceOf(address(this));

airdrop(ERC20(asset), address(this), needed);

Expand All @@ -447,7 +449,7 @@ contract BaseAuctioneerTest is Setup {

(, , , _available) = auctioneer.auctionInfo(id);
assertEq(_available, kickable - toTake);
assertEq(ERC20(asset).balanceOf(address(this)), 0);
assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset);
assertEq(ERC20(from).balanceOf(address(this)), before + toTake);
assertEq(ERC20(from).balanceOf(address(auctioneer)), left);
assertEq(ERC20(asset).balanceOf(address(auctioneer)), needed);
Expand Down

0 comments on commit d732e91

Please sign in to comment.