diff --git a/src/swappers/TradeFactorySwapper.sol b/src/swappers/TradeFactorySwapper.sol index 77cfcb9..ae794f9 100644 --- a/src/swappers/TradeFactorySwapper.sol +++ b/src/swappers/TradeFactorySwapper.sol @@ -28,7 +28,7 @@ abstract contract TradeFactorySwapper { * proper functions to avoid issues. * @return The current trade factory in use if any. */ - function tradeFactory() public view returns (address) { + function tradeFactory() public view virtual returns (address) { return _tradeFactory; } @@ -38,14 +38,17 @@ abstract contract TradeFactorySwapper { * proper functions to avoid issues. * @return The current array of tokens being sold if any. */ - function rewardTokens() public view returns (address[] memory) { + function rewardTokens() public view virtual returns (address[] memory) { return _rewardTokens; } /** * @dev Add an array of tokens to sell to its corresponding `_to_. */ - function _addTokens(address[] memory _from, address[] memory _to) internal { + function _addTokens( + address[] memory _from, + address[] memory _to + ) internal virtual { for (uint256 i; i < _from.length; ++i) { _addToken(_from[i], _to[i]); } @@ -54,7 +57,7 @@ abstract contract TradeFactorySwapper { /** * @dev Add the `_tokenFrom` to be sold to `_tokenTo` through the Trade Factory */ - function _addToken(address _tokenFrom, address _tokenTo) internal { + function _addToken(address _tokenFrom, address _tokenTo) internal virtual { address _tf = tradeFactory(); if (_tf != address(0)) { ERC20(_tokenFrom).forceApprove(_tf, type(uint256).max); @@ -68,7 +71,10 @@ abstract contract TradeFactorySwapper { * @dev Remove a specific `_tokenFrom` that was previously added to not be * sold through the Trade Factory any more. */ - function _removeToken(address _tokenFrom, address _tokenTo) internal { + function _removeToken( + address _tokenFrom, + address _tokenTo + ) internal virtual { address _tf = tradeFactory(); address[] memory _rewardTokensLocal = rewardTokens(); for (uint256 i; i < _rewardTokensLocal.length; ++i) { @@ -95,7 +101,7 @@ abstract contract TradeFactorySwapper { /** * @dev Removes all reward tokens and delete the Trade Factory. */ - function _deleteRewardTokens() internal { + function _deleteRewardTokens() internal virtual { _removeTradeFactoryPermissions(); delete _rewardTokens; } @@ -109,7 +115,7 @@ abstract contract TradeFactorySwapper { function _setTradeFactory( address tradeFactory_, address _tokenTo - ) internal { + ) internal virtual { address _tf = tradeFactory(); // Remove any old Trade Factory @@ -136,7 +142,7 @@ abstract contract TradeFactorySwapper { /** * @dev Remove any active approvals and set the trade factory to address(0). */ - function _removeTradeFactoryPermissions() internal { + function _removeTradeFactoryPermissions() internal virtual { address _tf = tradeFactory(); address[] memory rewardTokensLocal = rewardTokens(); for (uint256 i; i < rewardTokensLocal.length; ++i) { @@ -149,7 +155,7 @@ abstract contract TradeFactorySwapper { /** * @notice Used for TradeFactory to claim rewards. */ - function claimRewards() external { + function claimRewards() external virtual { require(msg.sender == _tradeFactory, "!authorized"); _claimRewards(); } diff --git a/src/test/Auction.t.sol b/src/test/Auction.t.sol index dce5279..826426e 100644 --- a/src/test/Auction.t.sol +++ b/src/test/Auction.t.sol @@ -295,6 +295,7 @@ contract AuctionTest is Setup, ITaker { skip(auction.auctionLength() / 2); uint256 needed = auction.getAmountNeeded(id, _amount); + uint256 beforeAsset = ERC20(asset).balanceOf(address(this)); airdrop(ERC20(asset), address(this), needed); @@ -311,7 +312,7 @@ contract AuctionTest is Setup, ITaker { (, , , _available) = auction.auctionInfo(id); assertEq(_available, 0); - assertEq(ERC20(asset).balanceOf(address(this)), 0); + assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset); assertEq(ERC20(from).balanceOf(address(this)), before + _amount); assertEq(ERC20(from).balanceOf(address(auction)), 0); assertEq(ERC20(asset).balanceOf(address(mockStrategy)), needed); @@ -351,6 +352,7 @@ contract AuctionTest is Setup, ITaker { uint256 toTake = (_amount * _percent) / MAX_BPS; uint256 left = _amount - toTake; uint256 needed = auction.getAmountNeeded(id, toTake); + uint256 beforeAsset = ERC20(asset).balanceOf(address(this)); airdrop(ERC20(asset), address(this), needed); @@ -366,7 +368,7 @@ contract AuctionTest is Setup, ITaker { (, , , _available) = auction.auctionInfo(id); assertEq(_available, left); - assertEq(ERC20(asset).balanceOf(address(this)), 0); + assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset); assertEq(ERC20(from).balanceOf(address(this)), before + toTake); assertEq(ERC20(from).balanceOf(address(auction)), left); assertEq(ERC20(asset).balanceOf(address(mockStrategy)), needed); @@ -405,6 +407,7 @@ contract AuctionTest is Setup, ITaker { uint256 toTake = _amount / 2; uint256 left = _amount - toTake; uint256 needed = auction.getAmountNeeded(id, toTake); + uint256 beforeAsset = ERC20(asset).balanceOf(address(this)); airdrop(ERC20(asset), address(this), needed); @@ -424,7 +427,7 @@ contract AuctionTest is Setup, ITaker { (, , , _available) = auction.auctionInfo(id); assertEq(_available, left); - assertEq(ERC20(asset).balanceOf(address(this)), 0); + assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset); assertEq(ERC20(from).balanceOf(address(this)), before + toTake); assertEq(ERC20(from).balanceOf(address(auction)), left); assertEq(ERC20(asset).balanceOf(address(mockStrategy)), needed); diff --git a/src/test/AuctionSwapper.t.sol b/src/test/AuctionSwapper.t.sol index d8fb5bd..cc6be41 100644 --- a/src/test/AuctionSwapper.t.sol +++ b/src/test/AuctionSwapper.t.sol @@ -309,6 +309,7 @@ contract AuctionSwapperTest is Setup { uint256 toTake = (_amount * _percent) / MAX_BPS; uint256 left = _amount - toTake; uint256 needed = auction.getAmountNeeded(id, toTake); + uint256 beforeAsset = ERC20(asset).balanceOf(address(this)); airdrop(ERC20(asset), address(this), needed); @@ -324,7 +325,7 @@ contract AuctionSwapperTest is Setup { (, , , _available) = auction.auctionInfo(id); assertEq(_available, left); - assertEq(ERC20(asset).balanceOf(address(this)), 0); + assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset); assertEq(ERC20(from).balanceOf(address(this)), before + toTake); assertEq(ERC20(from).balanceOf(address(auction)), left); assertEq(ERC20(asset).balanceOf(address(swapper)), needed); @@ -470,6 +471,7 @@ contract AuctionSwapperTest is Setup { uint256 toTake = (kickable * _percent) / MAX_BPS; uint256 left = kickable - toTake; uint256 needed = auction.getAmountNeeded(id, toTake); + uint256 beforeAsset = ERC20(asset).balanceOf(address(this)); airdrop(ERC20(asset), address(this), needed); @@ -487,7 +489,7 @@ contract AuctionSwapperTest is Setup { (, , , _available) = auction.auctionInfo(id); assertEq(_available, left); - assertEq(ERC20(asset).balanceOf(address(this)), 0); + assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset); assertEq(ERC20(from).balanceOf(address(this)), before + toTake); assertEq(ERC20(from).balanceOf(address(auction)), left); assertEq(ERC20(asset).balanceOf(address(swapper)), needed); diff --git a/src/test/BaseAuctioneer.t.sol b/src/test/BaseAuctioneer.t.sol index c689418..ea9df23 100644 --- a/src/test/BaseAuctioneer.t.sol +++ b/src/test/BaseAuctioneer.t.sol @@ -276,6 +276,7 @@ contract BaseAuctioneerTest is Setup { uint256 toTake = (_amount * _percent) / MAX_BPS; uint256 left = _amount - toTake; uint256 needed = auctioneer.getAmountNeeded(id, toTake); + uint256 beforeAsset = ERC20(asset).balanceOf(address(this)); airdrop(ERC20(asset), address(this), needed); @@ -291,7 +292,7 @@ contract BaseAuctioneerTest is Setup { (, , , _available) = auctioneer.auctionInfo(id); assertEq(_available, left); - assertEq(ERC20(asset).balanceOf(address(this)), 0); + assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset); assertEq(ERC20(from).balanceOf(address(this)), before + toTake); assertEq(ERC20(from).balanceOf(address(auctioneer)), left); assertEq(ERC20(asset).balanceOf(address(auctioneer)), needed); @@ -430,6 +431,7 @@ contract BaseAuctioneerTest is Setup { uint256 toTake = (kickable * _percent) / MAX_BPS; uint256 left = _amount - toTake; uint256 needed = auctioneer.getAmountNeeded(id, toTake); + uint256 beforeAsset = ERC20(asset).balanceOf(address(this)); airdrop(ERC20(asset), address(this), needed); @@ -447,7 +449,7 @@ contract BaseAuctioneerTest is Setup { (, , , _available) = auctioneer.auctionInfo(id); assertEq(_available, kickable - toTake); - assertEq(ERC20(asset).balanceOf(address(this)), 0); + assertEq(ERC20(asset).balanceOf(address(this)), beforeAsset); assertEq(ERC20(from).balanceOf(address(this)), before + toTake); assertEq(ERC20(from).balanceOf(address(auctioneer)), left); assertEq(ERC20(asset).balanceOf(address(auctioneer)), needed);