diff --git a/uniswap/uniswap.py b/uniswap/uniswap.py index 894971e..fc917d5 100644 --- a/uniswap/uniswap.py +++ b/uniswap/uniswap.py @@ -147,7 +147,9 @@ def __init__( self.w3, abi_name="uniswap-v3/router", address=self.router_address ) else: - raise Exception(f"Invalid version '{self.version}', only 1, 2 or 3 supported") + raise Exception( + f"Invalid version '{self.version}', only 1, 2 or 3 supported" + ) if hasattr(self, "factory_contract"): logger.info(f"Using factory contract: {self.factory_contract}") @@ -156,13 +158,13 @@ def __init__( def get_price_input( self, - token0: AddressLike, - token1: AddressLike, + token0: AddressLike, # input token + token1: AddressLike, # output token qty: int, fee: int = None, route: Optional[List[AddressLike]] = None, ) -> int: - """Given `qty` amount of `token0`, returns the maximum output amount of `token1`.""" + """Given `qty` amount of the input `token0`, returns the maximum output amount of output `token1`.""" if fee is None: fee = 3000 if self.version == 3: @@ -196,8 +198,13 @@ def get_price_output( else: return self._get_token_token_output_price(token0, token1, qty, fee, route) - def _get_eth_token_input_price(self, token: AddressLike, qty: Wei, fee: int) -> Wei: - """Public price for ETH to Token trades with an exact input.""" + def _get_eth_token_input_price( + self, + token: AddressLike, # output token + qty: Wei, + fee: int, + ) -> Wei: + """Public price (i.e. amount of output token received) for ETH to token trades with an exact input.""" if self.version == 1: ex = self._exchange_contract(token) price: Wei = ex.functions.getEthToTokenInputPrice(qty).call() @@ -211,8 +218,13 @@ def _get_eth_token_input_price(self, token: AddressLike, qty: Wei, fee: int) -> ) # type: ignore return price - def _get_token_eth_input_price(self, token: AddressLike, qty: int, fee: int) -> int: - """Public price for token to ETH trades with an exact input.""" + def _get_token_eth_input_price( + self, + token: AddressLike, # input token + qty: int, + fee: int, + ) -> int: + """Public price (i.e. amount of ETH received) for token to ETH trades with an exact input.""" if self.version == 1: ex = self._exchange_contract(token) price: int = ex.functions.getTokenToEthInputPrice(qty).call() @@ -228,14 +240,14 @@ def _get_token_eth_input_price(self, token: AddressLike, qty: int, fee: int) -> def _get_token_token_input_price( self, - token0: AddressLike, - token1: AddressLike, + token0: AddressLike, # input token + token1: AddressLike, # output token qty: int, fee: int, route: Optional[List[AddressLike]] = None, ) -> int: """ - Public price for token to token trades with an exact input. + Public price (i.e. amount of output token received) for token to token trades with an exact input. :param fee: (v3 only) The pool's fee in hundredths of a bip, i.e. 1e-6 (3000 is 0.3%) """ @@ -269,9 +281,12 @@ def _get_token_token_input_price( return price def _get_eth_token_output_price( - self, token: AddressLike, qty: int, fee: int = None + self, + token: AddressLike, # output token + qty: int, + fee: int = None, ) -> Wei: - """Public price for ETH to Token trades with an exact output.""" + """Public price (i.e. amount of ETH needed) for ETH to token trades with an exact output.""" if self.version == 1: ex = self._exchange_contract(token) price: Wei = ex.functions.getEthToTokenOutputPrice(qty).call() @@ -290,9 +305,9 @@ def _get_eth_token_output_price( return price def _get_token_eth_output_price( - self, token: AddressLike, qty: Wei, fee: int = None + self, token: AddressLike, qty: Wei, fee: int = None # input token ) -> int: - """Public price for token to ETH trades with an exact output.""" + """Public price (i.e. amount of input token needed) for token to ETH trades with an exact output.""" if self.version == 1: ex = self._exchange_contract(token) price: int = ex.functions.getTokenToEthOutputPrice(qty).call() @@ -310,14 +325,14 @@ def _get_token_eth_output_price( def _get_token_token_output_price( self, - token0: AddressLike, - token1: AddressLike, + token0: AddressLike, # input token + token1: AddressLike, # output token qty: int, fee: int = None, route: Optional[List[AddressLike]] = None, ) -> int: """ - Public price for token to token trades with an exact output. + Public price (i.e. amount of input token needed) for token to token trades with an exact output. :param fee: (v3 only) The pool's fee in hundredths of a bip, i.e. 1e-6 (3000 is 0.3%) """ @@ -376,28 +391,27 @@ def make_trade( if slippage is None: slippage = self.default_slippage + if input_token == output_token: + raise ValueError + if input_token == ETH_ADDRESS: return self._eth_to_token_swap_input( output_token, Wei(qty), recipient, fee, slippage, fee_on_transfer ) + elif output_token == ETH_ADDRESS: + return self._token_to_eth_swap_input( + input_token, qty, recipient, fee, slippage, fee_on_transfer + ) else: - balance = self.get_token_balance(input_token) - if balance < qty: - raise InsufficientBalance(balance, qty) - if output_token == ETH_ADDRESS: - return self._token_to_eth_swap_input( - input_token, qty, recipient, fee, slippage, fee_on_transfer - ) - else: - return self._token_to_token_swap_input( - input_token, - output_token, - qty, - recipient, - fee, - slippage, - fee_on_transfer, - ) + return self._token_to_token_swap_input( + input_token, + output_token, + qty, + recipient, + fee, + slippage, + fee_on_transfer, + ) @check_approval def make_trade_output( @@ -418,6 +432,9 @@ def make_trade_output( if slippage is None: slippage = self.default_slippage + if input_token == output_token: + raise ValueError + if input_token == ETH_ADDRESS: balance = self.get_eth_balance() need = self._get_eth_token_output_price(output_token, qty) @@ -427,9 +444,8 @@ def make_trade_output( output_token, qty, recipient, fee, slippage ) elif output_token == ETH_ADDRESS: - qty = Wei(qty) return self._token_to_eth_swap_output( - input_token, qty, recipient, fee, slippage + input_token, Wei(qty), recipient, fee, slippage ) else: return self._token_to_token_swap_output( @@ -446,6 +462,9 @@ def _eth_to_token_swap_input( fee_on_transfer: bool = False, ) -> HexBytes: """Convert ETH to tokens given an input amount.""" + if output_token == ETH_ADDRESS: + raise ValueError + eth_balance = self.get_eth_balance() if qty > eth_balance: raise InsufficientBalance(eth_balance, qty) @@ -483,10 +502,32 @@ def _eth_to_token_swap_input( self._get_tx_params(qty), ) elif self.version == 3: + if recipient is None: + recipient = self.address + if fee_on_transfer: raise Exception("fee on transfer not supported by Uniswap v3") - return self._token_to_token_swap_input( - self.get_weth_address(), output_token, qty, recipient, fee, slippage + + min_tokens_bought = int( + (1 - slippage) + * self._get_eth_token_input_price(output_token, qty, fee=fee) + ) + sqrtPriceLimitX96 = 0 + + return self._build_and_send_tx( + self.router.functions.exactInputSingle( + { + "tokenIn": self.get_weth_address(), + "tokenOut": output_token, + "fee": fee, + "recipient": recipient, + "deadline": self._deadline(), + "amountIn": qty, + "amountOutMinimum": min_tokens_bought, + "sqrtPriceLimitX96": sqrtPriceLimitX96, + } + ), + self._get_tx_params(value=qty), ) else: raise ValueError @@ -501,6 +542,9 @@ def _token_to_eth_swap_input( fee_on_transfer: bool = False, ) -> HexBytes: """Convert tokens to ETH given an input amount.""" + if input_token == ETH_ADDRESS: + raise ValueError + # Balance check input_balance = self.get_token_balance(input_token) if qty > input_balance: @@ -537,11 +581,45 @@ def _token_to_eth_swap_input( ), ) elif self.version == 3: + if recipient is None: + recipient = self.address + if fee_on_transfer: raise Exception("fee on transfer not supported by Uniswap v3") - return self._token_to_token_swap_input( - input_token, self.get_weth_address(), qty, recipient, fee, slippage + + output_token = self.get_weth_address() + min_tokens_bought = int( + (1 - slippage) + * self._get_token_eth_input_price(input_token, qty, fee=fee) + ) + sqrtPriceLimitX96 = 0 + + swap_data = self.router.encodeABI( + fn_name="exactInputSingle", + args=[ + ( + input_token, + output_token, + fee, + ETH_ADDRESS, + self._deadline(), + qty, + min_tokens_bought, + sqrtPriceLimitX96, + ) + ], ) + + unwrap_data = self.router.encodeABI( + fn_name="unwrapWETH9", args=[min_tokens_bought, recipient] + ) + + # Multicall + return self._build_and_send_tx( + self.router.functions.multicall([swap_data, unwrap_data]), + self._get_tx_params(), + ) + else: raise ValueError @@ -556,8 +634,19 @@ def _token_to_token_swap_input( fee_on_transfer: bool = False, ) -> HexBytes: """Convert tokens to tokens given an input amount.""" + # Balance check + input_balance = self.get_token_balance(input_token) + if qty > input_balance: + raise InsufficientBalance(input_balance, qty) + if recipient is None: recipient = self.address + + if input_token == ETH_ADDRESS: + raise ValueError + elif output_token == ETH_ADDRESS: + raise ValueError + if self.version == 1: token_funcs = self._exchange_contract(input_token).functions # TODO: This might not be correct @@ -602,6 +691,7 @@ def _token_to_token_swap_input( elif self.version == 3: if fee_on_transfer: raise Exception("fee on transfer not supported by Uniswap v3") + min_tokens_bought = int( (1 - slippage) * self._get_token_token_input_price( @@ -609,6 +699,7 @@ def _token_to_token_swap_input( ) ) sqrtPriceLimitX96 = 0 + return self._build_and_send_tx( self.router.functions.exactInputSingle( { @@ -622,9 +713,7 @@ def _token_to_token_swap_input( "sqrtPriceLimitX96": sqrtPriceLimitX96, } ), - self._get_tx_params( - Wei(qty) if input_token == self.get_weth_address() else Wei(0) - ), + self._get_tx_params(), ) else: raise ValueError @@ -638,6 +727,18 @@ def _eth_to_token_swap_output( slippage: float, ) -> HexBytes: """Convert ETH to tokens given an output amount.""" + if output_token == ETH_ADDRESS: + raise ValueError + + # Balance check + eth_balance = self.get_eth_balance() + cost = self._get_eth_token_output_price(output_token, qty, fee) + amount_in_max = Wei(int((1 + slippage) * cost)) + + # We check balance against amount_in_max rather than cost to be conservative + if amount_in_max > eth_balance: + raise InsufficientBalance(eth_balance, amount_in_max) + if self.version == 1: token_funcs = self._exchange_contract(output_token).functions eth_qty = self._get_eth_token_output_price(output_token, qty) @@ -666,8 +767,33 @@ def _eth_to_token_swap_output( self._get_tx_params(eth_qty), ) elif self.version == 3: - return self._token_to_token_swap_output( - self.get_weth_address(), output_token, qty, recipient, fee, slippage + if recipient is None: + recipient = self.address + + sqrtPriceLimitX96 = 0 + + swap_data = self.router.encodeABI( + fn_name="exactOutputSingle", + args=[ + ( + self.get_weth_address(), + output_token, + fee, + recipient, + self._deadline(), + qty, + amount_in_max, + sqrtPriceLimitX96, + ) + ], + ) + + refund_data = self.router.encodeABI(fn_name="refundETH", args=None) + + # Multicall + return self._build_and_send_tx( + self.router.functions.multicall([swap_data, refund_data]), + self._get_tx_params(value=amount_in_max), ) else: raise ValueError @@ -681,11 +807,17 @@ def _token_to_eth_swap_output( slippage: float, ) -> HexBytes: """Convert tokens to ETH given an output amount.""" + if input_token == ETH_ADDRESS: + raise ValueError + # Balance check input_balance = self.get_token_balance(input_token) cost = self._get_token_eth_output_price(input_token, qty, fee) - if cost > input_balance: - raise InsufficientBalance(input_balance, cost) + amount_in_max = int((1 + slippage) * cost) + + # We check balance against amount_in_max rather than cost to be conservative + if amount_in_max > input_balance: + raise InsufficientBalance(input_balance, amount_in_max) if self.version == 1: # From https://uniswap.org/docs/v1/frontend-integration/trade-tokens/ @@ -709,19 +841,49 @@ def _token_to_eth_swap_output( function = ex.functions.tokenToEthTransferOutput(*func_params) return self._build_and_send_tx(function) elif self.version == 2: + if recipient is None: + recipient = self.address + max_tokens = int((1 + slippage) * cost) return self._build_and_send_tx( self.router.functions.swapTokensForExactETH( qty, max_tokens, [input_token, self.get_weth_address()], - self.address, + recipient, self._deadline(), ), ) elif self.version == 3: - return self._token_to_token_swap_output( - input_token, self.get_weth_address(), qty, recipient, fee, slippage + if recipient is None: + recipient = self.address + + sqrtPriceLimitX96 = 0 + + swap_data = self.router.encodeABI( + fn_name="exactOutputSingle", + args=[ + ( + input_token, + self.get_weth_address(), + fee, + ETH_ADDRESS, + self._deadline(), + qty, + amount_in_max, + sqrtPriceLimitX96, + ) + ], + ) + + unwrap_data = self.router.encodeABI( + fn_name="unwrapWETH9", args=[qty, recipient] + ) + + # Multicall + return self._build_and_send_tx( + self.router.functions.multicall([swap_data, unwrap_data]), + self._get_tx_params(), ) else: raise ValueError @@ -735,11 +897,24 @@ def _token_to_token_swap_output( fee: int, slippage: float, ) -> HexBytes: - """ - Convert tokens to tokens given an output amount. + """Convert tokens to tokens given an output amount. :param fee: TODO """ + if input_token == ETH_ADDRESS: + raise ValueError + elif output_token == ETH_ADDRESS: + raise ValueError + + # Balance check + input_balance = self.get_token_balance(input_token) + cost = self._get_token_token_output_price(input_token, output_token, qty, fee) + amount_in_max = int((1 + slippage) * cost) + if ( + amount_in_max > input_balance + ): # We check balance against amount_in_max rather than cost to be conservative + raise InsufficientBalance(input_balance, amount_in_max) + if self.version == 1: token_funcs = self._exchange_contract(input_token).functions max_tokens_sold, max_eth_sold = self._calculate_max_input_token( @@ -779,10 +954,6 @@ def _token_to_token_swap_output( if recipient is None: recipient = self.address - cost = self._get_token_token_output_price( - input_token, output_token, qty, fee=fee - ) - amount_in_max = int((1 + slippage) * cost) sqrtPriceLimitX96 = 0 return self._build_and_send_tx( @@ -798,11 +969,7 @@ def _token_to_token_swap_output( "sqrtPriceLimitX96": sqrtPriceLimitX96, }, ), - self._get_tx_params( - Wei(amount_in_max) - if input_token == self.get_weth_address() - else Wei(0) - ), + self._get_tx_params(), ) else: raise ValueError