diff --git a/python/thirdweb-ai/pyproject.toml b/python/thirdweb-ai/pyproject.toml index dbff462..2f0da58 100644 --- a/python/thirdweb-ai/pyproject.toml +++ b/python/thirdweb-ai/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "thirdweb-ai" -version = "0.1.9" +version = "0.1.10" description = "thirdweb AI" authors = [{ name = "thirdweb", email = "support@thirdweb.com" }] requires-python = ">=3.10,<4.0" @@ -20,6 +20,7 @@ dependencies = [ "jsonref>=1.1.0,<2", "httpx>=0.28.1,<0.29", "aiohttp>=3.11.14", + "web3>=7.9.0", ] [project.optional-dependencies] diff --git a/python/thirdweb-ai/src/thirdweb_ai/common/address.py b/python/thirdweb-ai/src/thirdweb_ai/common/address.py new file mode 100644 index 0000000..f82dd33 --- /dev/null +++ b/python/thirdweb-ai/src/thirdweb_ai/common/address.py @@ -0,0 +1,52 @@ +import re + +from web3 import Web3 + + +def validate_address(address: str) -> str: + if not address.startswith("0x") or len(address) != 42: + raise ValueError(f"Invalid blockchain address format: {address}") + + if not Web3.is_checksum_address(address): + try: + return Web3.to_checksum_address(address) + except ValueError as e: + raise ValueError(f"Invalid blockchain address: {address}") from e + + return address + + +def validate_transaction_hash(tx_hash: str) -> str: + pattern = re.compile(r"^0x[a-fA-F0-9]{64}$") + if bool(re.fullmatch(pattern, tx_hash)): + return tx_hash + raise ValueError(f"Invalid transaction hash: {tx_hash}") + + +def validate_block_identifier(block_id: str) -> str: + if block_id.startswith("0x"): + pattern = re.compile(r"^0x[a-fA-F0-9]{64}$") + if bool(re.fullmatch(pattern, block_id)): + return block_id + elif block_id.isdigit(): + return block_id + + raise ValueError(f"Invalid block identifier: {block_id}") + + +def validate_signature(signature: str) -> str: + # Function selector (4 bytes) + if signature.startswith("0x") and len(signature) == 10: + pattern = re.compile(r"^0x[a-fA-F0-9]{8}$") + if bool(re.fullmatch(pattern, signature)): + return signature + # Event topic (32 bytes) + elif signature.startswith("0x") and len(signature) == 66: + pattern = re.compile(r"^0x[a-fA-F0-9]{64}$") + if bool(re.fullmatch(pattern, signature)): + return signature + # Plain text signature (e.g. "transfer(address,uint256)") + elif "(" in signature and ")" in signature: + return signature + + raise ValueError(f"Invalid function or event signature: {signature}") diff --git a/python/thirdweb-ai/src/thirdweb_ai/common/utils.py b/python/thirdweb-ai/src/thirdweb_ai/common/utils.py index a664415..8d02f9f 100644 --- a/python/thirdweb-ai/src/thirdweb_ai/common/utils.py +++ b/python/thirdweb-ai/src/thirdweb_ai/common/utils.py @@ -1,8 +1,33 @@ import re from typing import Any +TRANSACTION_KEYS_TO_KEEP = [ + "hash", + "block_number", + "block_timestamp", + "from_address", + "to_address", + "value", + "decodedData", +] +EVENT_KEYS_TO_KEEP = [ + "block_number", + "block_timestamp", + "address", + "transaction_hash", + "transaction_index", + "log_index", + "topics", + "data", + "decodedData", +] + def extract_digits(value: int | str) -> int: + """Extract the integer value from a string or return the integer directly.""" + if isinstance(value, int): + return value + value_str = str(value).strip("\"'") digit_match = re.search(r"\d+", value_str) @@ -16,21 +41,8 @@ def extract_digits(value: int | str) -> int: return int(extracted_digits) -def normalize_chain_id( - in_value: int | str | list[int | str] | None, -) -> int | list[int] | None: - """Normalize str values integers.""" - - if in_value is None: - return None - - if isinstance(in_value, list): - return [extract_digits(c) for c in in_value] - - return extract_digits(in_value) - - def is_encoded(encoded_data: str) -> bool: + """Check if a string is a valid hexadecimal value.""" encoded_data = encoded_data.removeprefix("0x") try: @@ -41,6 +53,7 @@ def is_encoded(encoded_data: str) -> bool: def clean_resolve(out: dict[str, Any]): + """Clean the response from the resolve function.""" if "transactions" in out["data"]: for transaction in out["data"]["transactions"]: if "data" in transaction and is_encoded(transaction["data"]): @@ -48,3 +61,15 @@ def clean_resolve(out: dict[str, Any]): if "logs_bloom" in transaction: transaction.pop("logs_bloom") return out + + +def filter_response_keys(items: list[dict[str, Any]], keys_to_keep: list[str] | None) -> list[dict[str, Any]]: + """Filter the response items to only include the specified keys""" + if not keys_to_keep: + return items + + for item in items: + keys_to_remove = [key for key in item if key not in keys_to_keep] + for key in keys_to_remove: + item.pop(key, None) + return items diff --git a/python/thirdweb-ai/src/thirdweb_ai/services/engine.py b/python/thirdweb-ai/src/thirdweb_ai/services/engine.py index 1c00ee4..c1c3f1d 100644 --- a/python/thirdweb-ai/src/thirdweb_ai/services/engine.py +++ b/python/thirdweb-ai/src/thirdweb_ai/services/engine.py @@ -1,6 +1,6 @@ -from typing import Annotated, Any +from typing import Annotated, Any, Literal -from thirdweb_ai.common.utils import extract_digits, normalize_chain_id +from thirdweb_ai.common.utils import extract_digits from thirdweb_ai.services.service import Service from thirdweb_ai.tools.tool import tool @@ -10,7 +10,7 @@ def __init__( self, engine_url: str, engine_auth_jwt: str, - chain_id: int | str | None = None, + chain_id: int | None = None, backend_wallet_address: str | None = None, secret_key: str = "", ): @@ -18,7 +18,7 @@ def __init__( self.engine_url = engine_url self.engine_auth_jwt = engine_auth_jwt self.backend_wallet_address = backend_wallet_address - self.chain_id = normalize_chain_id(chain_id) + self.chain_id = chain_id def _make_headers(self): headers = super()._make_headers() @@ -34,7 +34,7 @@ def _make_headers(self): def create_backend_wallet( self, wallet_type: Annotated[ - str, + Literal["local", "smart:local"], "The type of backend wallet to create. Currently supported options are 'local' (stored locally in Engine's database) or 'smart:local' (for smart account wallets with advanced features). Choose 'local' for standard EOA wallets, and 'smart:local' for smart contract wallets with batching capabilities.", ], label: Annotated[ @@ -76,7 +76,7 @@ def get_all_backend_wallet( def get_wallet_balance( self, chain_id: Annotated[ - str | int, + int | None, "The numeric blockchain network ID to query (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.", ], backend_wallet_address: Annotated[ @@ -85,9 +85,13 @@ def get_wallet_balance( ] = None, ) -> dict[str, Any]: """Get wallet balance for native or ERC20 tokens.""" - normalized_chain = normalize_chain_id(chain_id) or self.chain_id + if self.chain_id is not None and chain_id is None: + chain_id = self.chain_id + elif chain_id is None: + raise ValueError("chain_id is required") + backend_wallet_address = backend_wallet_address or self.backend_wallet_address - return self._get(f"backend-wallet/{normalized_chain}/{backend_wallet_address}/get-balance") + return self._get(f"backend-wallet/{chain_id}/{backend_wallet_address}/get-balance") @tool( description="Send an on-chain transaction. This powerful function can transfer native currency (ETH, MATIC), ERC20 tokens, or execute any arbitrary contract interaction. The transaction is signed and broadcast to the blockchain automatically." @@ -107,7 +111,7 @@ def send_transaction( "The hexadecimal transaction data payload for contract interactions (e.g., '0x23b872dd...'). For simple native currency transfers, leave this empty. For ERC20 transfers or contract calls, this contains the ABI-encoded function call.", ], chain_id: Annotated[ - str | int, + int | None, "The numeric blockchain network ID to send the transaction on (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.", ], backend_wallet_address: Annotated[ @@ -126,10 +130,14 @@ def send_transaction( "data": data or "0x", } - normalized_chain = normalize_chain_id(chain_id) or self.chain_id + if self.chain_id is not None and chain_id is None: + chain_id = self.chain_id + elif chain_id is None: + raise ValueError("chain_id is required") + backend_wallet_address = backend_wallet_address or self.backend_wallet_address return self._post( - f"backend-wallet/{normalized_chain}/send-transaction", + f"backend-wallet/{chain_id}/send-transaction", payload, headers={"X-Backend-Wallet-Address": backend_wallet_address}, ) @@ -165,7 +173,7 @@ def read_contract( "An ordered list of arguments to pass to the function (e.g., [address, tokenId]). Must match the types and order expected by the function. For functions with no parameters, use an empty list or None.", ], chain_id: Annotated[ - str | int, + int | None, "The numeric blockchain network ID where the contract is deployed (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.", ], ) -> dict[str, Any]: @@ -174,8 +182,12 @@ def read_contract( "functionName": function_name, "args": function_args or [], } - normalized_chain = normalize_chain_id(chain_id) or self.chain_id - return self._get(f"contract/{normalized_chain}/{contract_address}/read", payload) + if self.chain_id is not None and chain_id is None: + chain_id = self.chain_id + elif chain_id is None: + raise ValueError("chain_id is required") + + return self._get(f"contract/{chain_id}/{contract_address}/read", payload) @tool( description="Execute a state-changing function on a smart contract by sending a transaction. This allows you to modify on-chain data, such as transferring tokens, minting NFTs, or updating contract configuration. The transaction is automatically signed by your backend wallet and submitted to the blockchain." @@ -199,7 +211,7 @@ def write_contract( "The amount of native currency (ETH, MATIC, etc.) to send with the transaction, in wei (e.g., '1000000000000000000' for 1 ETH). Required for payable functions, use '0' for non-payable functions. Default to '0'.", ], chain_id: Annotated[ - str | int, + int | None, "The numeric blockchain network ID where the contract is deployed (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.", ], ) -> dict[str, Any]: @@ -212,9 +224,13 @@ def write_contract( if value and value != "0": payload["txOverrides"] = {"value": value} - normalized_chain = normalize_chain_id(chain_id) or self.chain_id + if self.chain_id is not None and chain_id is None: + chain_id = self.chain_id + elif chain_id is None: + raise ValueError("chain_id is required") + return self._post( - f"contract/{normalized_chain}/{contract_address}/write", + f"contract/{chain_id}/{contract_address}/write", payload, headers={"X-Backend-Wallet-Address": self.backend_wallet_address}, ) diff --git a/python/thirdweb-ai/src/thirdweb_ai/services/insight.py b/python/thirdweb-ai/src/thirdweb_ai/services/insight.py index 3d07960..94f6cba 100644 --- a/python/thirdweb-ai/src/thirdweb_ai/services/insight.py +++ b/python/thirdweb-ai/src/thirdweb_ai/services/insight.py @@ -1,75 +1,60 @@ -from typing import Annotated, Any +from typing import Annotated, Any, Literal -from thirdweb_ai.common.utils import clean_resolve, normalize_chain_id +from thirdweb_ai.common.address import ( + validate_address, + validate_block_identifier, + validate_signature, + validate_transaction_hash, +) +from thirdweb_ai.common.utils import EVENT_KEYS_TO_KEEP, TRANSACTION_KEYS_TO_KEEP, clean_resolve, filter_response_keys from thirdweb_ai.services.service import Service from thirdweb_ai.tools.tool import tool class Insight(Service): - def __init__(self, secret_key: str, chain_id: int | str | list[int | str]): + def __init__(self, secret_key: str, chain_id: int | list[int] | None = None): super().__init__(base_url="https://insight.thirdweb.com/v1", secret_key=secret_key) - normalized = normalize_chain_id(chain_id) - self.chain_ids = normalized if isinstance(normalized, list) else [normalized] + self.chain_ids = [chain_id or 1] @tool( description="Retrieve blockchain events with flexible filtering options. Use this to search for specific events or to analyze event patterns across multiple blocks. Do not use this tool to simply look up a single transaction." ) - def get_all_events( + def get_events( self, - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) to query (e.g., 1 for Ethereum Mainnet, 137 for Polygon). Specify multiple IDs as a list [1, 137] for cross-chain queries (max 5).", ] = None, contract_address: Annotated[ str | None, "Contract address to filter events by (e.g., '0x1234...'). Only return events emitted by this contract.", ] = None, - block_number_gte: Annotated[int | None, "Minimum block number to start querying from (inclusive)."] = None, - block_number_lt: Annotated[int | None, "Maximum block number to query up to (exclusive)."] = None, transaction_hash: Annotated[ str | None, "Specific transaction hash to filter events by (e.g., '0xabc123...'). Useful for examining events in a particular transaction.", ] = None, - topic_0: Annotated[ - str | None, - "Filter by event signature hash (first topic). For example, '0xa6697e974e6a320f454390be03f74955e8978f1a6971ea6730542e37b66179bc' for Transfer events.", - ] = None, - limit: Annotated[ - int | None, - "Maximum number of events to return per request. Default is 20, adjust for pagination.", - ] = None, page: Annotated[ int | None, - "Page number for paginated results, starting from 0. Use with limit parameter.", + "Page number for paginated results, starting from 0. 20 results are returned per page.", ] = None, - sort_order: Annotated[ - str | None, - "Sort order for the events. Default is 'desc' for descending order. Use 'asc' for ascending order.", - ] = "desc", ) -> dict[str, Any]: params: dict[str, Any] = { "sort_by": "block_number", - "sort_order": sort_order if sort_order in ["asc", "desc"] else "desc", + "sort_order": "desc", "decode": True, } - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids if contract_address: - params["filter_address"] = contract_address - if block_number_gte: - params["filter_block_number_gte"] = block_number_gte - if block_number_lt: - params["filter_block_number_lt"] = block_number_lt + params["filter_address"] = validate_address(contract_address) if transaction_hash: - params["filter_transaction_hash"] = transaction_hash - if topic_0: - params["filter_topic_0"] = topic_0 - if limit: - params["limit"] = limit + params["filter_transaction_hash"] = validate_transaction_hash(transaction_hash) if page: params["page"] = page - return self._get("events", params) + out = self._get("events", params) + out["data"] = filter_response_keys(out["data"], EVENT_KEYS_TO_KEEP) + return out @tool( description="Retrieve events from a specific contract address. Use this to analyze activity or monitor events for a particular smart contract." @@ -80,56 +65,36 @@ def get_contract_events( str, "The contract address to query events for (e.g., '0x1234...'). Must be a valid Ethereum address.", ], - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) to query (e.g., 1 for Ethereum Mainnet, 137 for Polygon). Specify multiple IDs as a list for cross-chain queries (max 5).", ] = None, - block_number_gte: Annotated[ - int | None, - "Only return events from blocks with number greater than or equal to this value. Useful for querying recent history.", - ] = None, - topic_0: Annotated[ - str | None, - "Filter by event signature hash (first topic). For example, Transfer event has a specific signature hash.", - ] = None, - limit: Annotated[ - int | None, - "Maximum number of events to return per request. Default is 20, increase for more results.", - ] = None, page: Annotated[ int | None, - "Page number for paginated results, starting from 0. Use with limit parameter for browsing large result sets.", + "Page number for paginated results, starting from 0. 20 results are returned per page.", ] = None, - sort_order: Annotated[ - str | None, - "Sort order for the events. Default is 'desc' for descending order. Use 'asc' for ascending order.", - ] = "desc", ) -> dict[str, Any]: params: dict[str, Any] = { "sort_by": "block_number", - "sort_order": sort_order if sort_order in ["asc", "desc"] else "desc", + "sort_order": "desc", "decode": True, } - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain - if block_number_gte: - params["filter_block_number_gte"] = block_number_gte - if topic_0: - params["filter_topic_0"] = topic_0 - if limit: - params["limit"] = limit + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids if page: params["page"] = page + + contract_address = validate_address(contract_address) return self._get(f"events/{contract_address}", params) @tool( - description="Retrieve blockchain transactions with flexible filtering options. Use this to analyze transaction patterns, track specific transactions, or monitor wallet activity." + description="Retrieve blockchain transactions with flexible filtering options. Use this to find transactions from or to an address, or between two addresses." ) - def get_all_transactions( + def get_filtered_transactions( self, - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) to query (e.g., 1 for Ethereum, 137 for Polygon). Specify multiple IDs as a list for cross-chain queries.", ] = None, from_address: Annotated[ @@ -140,138 +105,53 @@ def get_all_transactions( str | None, "Filter transactions sent to this address (e.g., '0x1234...'). Useful for tracking incoming transactions to a contract or wallet.", ] = None, - function_selector: Annotated[ - str | None, - "Filter by function selector (e.g., '0x095ea7b3' for the approve function). Useful for finding specific contract interactions.", - ] = None, - sort_order: Annotated[ - str | None, - "Sort order for the transactions. Default is 'asc' for ascending order. Use 'desc' for descending order.", - ] = "desc", - limit: Annotated[ - int | None, - "Maximum number of transactions to return per request. Default is 20, adjust based on your needs.", - ] = None, page: Annotated[ int | None, - "Page number for paginated results, starting from 0. Use with limit parameter for browsing large result sets.", + "Page number for paginated results, starting from 0. 20 results are returned per page.", ] = None, ) -> dict[str, Any]: params: dict[str, Any] = { "sort_by": "block_number", - "sort_order": sort_order if sort_order in ["asc", "desc"] else "desc", "decode": True, } - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids if from_address: - params["filter_from_address"] = from_address + params["filter_from_address"] = validate_address(from_address) if to_address: - params["filter_to_address"] = to_address - if function_selector: - params["filter_function_selector"] = function_selector - if limit: - params["limit"] = limit + params["filter_to_address"] = validate_address(to_address) if page: params["page"] = page - return self._get("transactions", params) - @tool( - description="Retrieve ERC20 token balances for a specified address. Lists all fungible tokens owned with their balances, metadata, and optionally prices." - ) - def get_erc20_tokens( - self, - owner_address: Annotated[ - str, - "The wallet address to get ERC20 token balances for (e.g., '0x1234...'). Must be a valid Ethereum address.", - ], - chain: Annotated[ - list[int | str] | int | str | None, - "Chain ID(s) to query (e.g., 1 for Ethereum, 137 for Polygon). Specify multiple IDs as a list for cross-chain queries.", - ] = None, - include_price: Annotated[ - bool | None, - "Set to True to include current market prices for tokens. Useful for calculating portfolio value.", - ] = None, - include_spam: Annotated[ - bool | None, - "Set to True to include suspected spam tokens. Default is False to filter out unwanted tokens.", - ] = None, - ) -> dict[str, Any]: - params: dict[str, Any] = {} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain - if include_price: - params["include_price"] = include_price - if include_spam: - params["include_spam"] = include_spam - return self._get(f"tokens/erc20/{owner_address}", params) + out = self._get("transactions", params) + out["data"] = filter_response_keys(out["data"], TRANSACTION_KEYS_TO_KEEP) + return out @tool( - description="Retrieve ERC721 NFTs (non-fungible tokens) owned by a specified address. Lists all unique NFTs with their metadata and optionally prices." + description="Retrieve token balances for a specified address. Lists all tokens owned with their balances, metadata, and prices. The default token type is erc20." ) - def get_erc721_tokens( + def get_address_tokens( self, owner_address: Annotated[ str, - "The wallet address to get ERC721 NFTs for (e.g., '0x1234...'). Returns all NFTs owned by this address.", + "The wallet address to get token balances for (e.g., '0x1234...'). Must be a valid Ethereum address.", ], - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) to query (e.g., 1 for Ethereum, 137 for Polygon). Specify multiple IDs as a list for cross-chain queries.", - ] = None, - include_price: Annotated[ - bool | None, - "Set to True to include estimated prices for NFTs where available. Useful for valuation.", - ] = None, - include_spam: Annotated[ - bool | None, - "Set to True to include suspected spam NFTs. Default is False to filter out potentially unwanted items.", - ] = None, - ) -> dict[str, Any]: - params = {} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain - if include_price: - params["include_price"] = include_price - if include_spam: - params["include_spam"] = include_spam - return self._get(f"tokens/erc721/{owner_address}", params) - - @tool( - description="Retrieve ERC1155 tokens (semi-fungible tokens) owned by a specified address. Shows balances of multi-token contracts with metadata." - ) - def get_erc1155_tokens( - self, - owner_address: Annotated[ - str, - "The wallet address to get ERC1155 tokens for (e.g., '0x1234...'). Returns all token IDs and their quantities.", ], - chain: Annotated[ - list[int | str] | int | str | None, - "Chain ID(s) to query (e.g., 1 for Ethereum, 137 for Polygon). Specify multiple IDs as a list for cross-chain queries.", - ] = None, - include_price: Annotated[ - bool | None, - "Set to True to include estimated prices for tokens where available. Useful for valuation.", - ] = None, - include_spam: Annotated[ - bool | None, - "Set to True to include suspected spam tokens. Default is False to filter out potentially unwanted items.", - ] = None, + token_type: Annotated[ + Literal["erc20", "erc721", "erc1155"], + "Type of token to query. erc20 means normal tokens, erc721 are NFTs, erc1155 are semi-fungible tokens", + ] = "erc20", ) -> dict[str, Any]: - params = {} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain - if include_price: - params["include_price"] = include_price - if include_spam: - params["include_spam"] = include_spam - return self._get(f"tokens/erc1155/{owner_address}", params) + params: dict[str, Any] = {"include_price": True, "include_spam": False} + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids + owner_address = validate_address(owner_address) + return self._get(f"tokens/{token_type.lower()}/{owner_address}", params) @tool( description="Get current market prices for native and ERC20 tokens. Useful for valuation, tracking portfolio value, or monitoring price changes." @@ -282,66 +162,64 @@ def get_token_prices( list[str], "List of token contract addresses to get prices for (e.g., ['0x1234...', '0x5678...']). Can include ERC20 tokens. Use '0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee' for native tokens (ETH, POL, MATIC, etc.).", ], - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) where the tokens exist (e.g., 1 for Ethereum, 137 for Polygon). Must match the token network.", ] = None, ) -> dict[str, Any]: + token_addresses = [validate_address(addr) for addr in token_addresses] params: dict[str, Any] = {"address": token_addresses} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids return self._get("tokens/price", params) @tool( - description="Get contract ABI and metadata about a smart contract, including name, symbol, decimals, and other contract-specific information. Use this when asked about a contract's functions, interface, or capabilities. This tool specifically retrieves details about deployed smart contracts (NOT regular wallet addresses or transaction hashes)." + description="Get metadata about a smart contract, including name, symbol, decimals, and other contract-specific information. Also returns the contract ABI which details how to interact with the contract. Use this when asked about a contract's functions, interface, or capabilities. This tool specifically retrieves details about deployed smart contracts (NOT regular wallet addresses or transaction hashes)." ) def get_contract_metadata( self, contract_address: Annotated[ str, - "The contract address to get metadata for (e.g., '0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2' for WETH). Must be a deployed smart contract address (not a regular wallet). Use this for queries like 'what functions does this contract have' or 'get the ABI for contract 0x1234...'.", + "The contract address to get metadata for (e.g., '0x1234...'). Must be a deployed smart contract address (not a regular wallet). Use this for queries like 'what functions does this contract have' or 'get the ABI for contract 0x1234...'.", ], - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) where the contract is deployed (e.g., 1 for Ethereum). Specify the correct network.", ] = None, ) -> dict[str, Any]: params = {} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids + contract_address = validate_address(contract_address) return self._get(f"contracts/metadata/{contract_address}", params) @tool( description="Retrieve detailed information about NFTs from a specific collection, including metadata, attributes, and images. Optionally get data for a specific token ID." ) - def get_nfts( + def get_nfts_from_contract( self, contract_address: Annotated[ str, - "The NFT contract address to query (e.g., '0x1234...'). Must be an ERC721 or ERC1155 contract.", + "The NFT contract address to query (e.g., '0x1234...'). Must be an ERC721 or ERC1155 contract address.", ], token_id: Annotated[ str | None, "Specific token ID to query (e.g., '42'). If provided, returns data only for this NFT. Otherwise returns collection data.", ] = None, - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) where the NFT contract is deployed (e.g., 1 for Ethereum). Specify the correct network.", ] = None, - include_metadata: Annotated[ - bool | None, - "Set to True to include full NFT metadata like attributes, image URL, etc. Useful for displaying NFT details.", - ] = None, ) -> dict[str, Any]: params = {} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain - if include_metadata: - params["include_metadata"] = include_metadata + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids + params["include_metadata"] = True + contract_address = validate_address(contract_address) if token_id: return self._get(f"nfts/{contract_address}/{token_id}", params) return self._get(f"nfts/{contract_address}", params) @@ -359,28 +237,23 @@ def get_nft_owners( str | None, "Specific token ID to query owners for (e.g., '42'). If provided, shows all owners of this specific NFT.", ] = None, - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) where the NFT contract is deployed (e.g., 1 for Ethereum). Specify the correct network.", ] = None, - limit: Annotated[ - int | None, - "Maximum number of ownership records to return per request. Default is 20, adjust for pagination.", - ] = None, page: Annotated[ int | None, - "Page number for paginated results, starting from 0. Use with limit parameter for browsing large collections.", + "Page number for paginated results, starting from 0. 20 results are returned per page.", ] = None, ) -> dict[str, Any]: params = {} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain - if limit: - params["limit"] = limit + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids if page: params["page"] = page + contract_address = validate_address(contract_address) if token_id: return self._get(f"nfts/owners/{contract_address}/{token_id}", params) return self._get(f"nfts/owners/{contract_address}", params) @@ -398,8 +271,8 @@ def get_nft_transfers( str | None, "Specific token ID to query transfers for (e.g., '42'). If provided, only shows transfers of this NFT.", ] = None, - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) to query (e.g., 1 for Ethereum). Specify the chain where the NFT contract is deployed.", ] = None, limit: Annotated[ @@ -412,14 +285,15 @@ def get_nft_transfers( ] = None, ) -> dict[str, Any]: params = {} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids if limit: params["limit"] = limit if page: params["page"] = page + contract_address = validate_address(contract_address) if token_id: return self._get(f"nfts/transfers/{contract_address}/{token_id}", params) return self._get(f"nfts/transfers/{contract_address}", params) @@ -433,15 +307,17 @@ def get_block_details( str, "Block number or block hash to look up. Can be either a simple number (e.g., '12345678') or a block hash (e.g., '0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3' for Ethereum block 0). Use for queries like 'what happened in block 14000000' or 'show me block 0xd4e56...'.", ], - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) to query (e.g., 1 for Ethereum). Specify the blockchain network where the block exists.", ] = None, ) -> dict[str, Any]: params = {} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids + + block_identifier = validate_block_identifier(block_identifier) out = self._get(f"resolve/{block_identifier}", params) return clean_resolve(out) @@ -452,18 +328,19 @@ def get_address_transactions( self, address: Annotated[ str, - "Wallet or contract address to look up (e.g., '0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045' for Vitalik's address). Must be a valid blockchain address starting with 0x and 42 characters long.", + "Wallet or contract address to look up (e.g., '0x1234...'). Must be a valid blockchain address starting with 0x and 42 characters long.", ], - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) to query (e.g., 1 for Ethereum). Specify the blockchain network for the address.", ] = None, ) -> dict[str, Any]: params = {} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain - out = self._get(f"resolve/{address}", params) + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids + validated_address = validate_address(address) + out = self._get(f"resolve/{validated_address}", params) return clean_resolve(out) @tool( @@ -475,15 +352,15 @@ def get_ens_transactions( str, "ENS name to resolve (e.g., 'vitalik.eth', 'thirdweb.eth'). Must be a valid ENS domain ending with .eth.", ], - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) to query (e.g., 1 for Ethereum). ENS is primarily on Ethereum mainnet.", ] = None, ) -> dict[str, Any]: params = {} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids out = self._get(f"resolve/{ens_name}", params) return clean_resolve(out) @@ -496,15 +373,17 @@ def get_transaction_details( str, "Transaction hash to look up (e.g., '0x5407ea41de24b7353d70eab42d72c92b42a44e140f930e349973cfc7b8c9c1d7'). Must be a valid transaction hash beginning with 0x and typically 66 characters long. Use this for queries like 'tell me about this transaction' or 'what happened in transaction 0x1234...'.", ], - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) to query (e.g., 1 for Ethereum). Specify the blockchain network where the transaction exists.", ] = None, ) -> dict[str, Any]: params = {} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids + + transaction_hash = validate_transaction_hash(transaction_hash) out = self._get(f"resolve/{transaction_hash}", params) return clean_resolve(out) @@ -517,14 +396,16 @@ def decode_signature( str, "Function or event signature to decode (e.g., '0x095ea7b3' for the approve function). Usually begins with 0x.", ], - chain: Annotated[ - list[int | str] | int | str | None, + chain_id: Annotated[ + list[int] | int | None, "Chain ID(s) to query (e.g., 1 for Ethereum). Specify to improve signature lookup accuracy.", ] = None, ) -> dict[str, Any]: params = {} - normalized_chain = normalize_chain_id(chain) if chain is not None else self.chain_ids - if normalized_chain: - params["chain"] = normalized_chain + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids + + signature = validate_signature(signature) out = self._get(f"resolve/{signature}", params) return clean_resolve(out) diff --git a/python/thirdweb-ai/src/thirdweb_ai/services/nebula.py b/python/thirdweb-ai/src/thirdweb_ai/services/nebula.py index 272d013..f3f070d 100644 --- a/python/thirdweb-ai/src/thirdweb_ai/services/nebula.py +++ b/python/thirdweb-ai/src/thirdweb_ai/services/nebula.py @@ -5,8 +5,10 @@ class Nebula(Service): - def __init__(self, secret_key: str): + def __init__(self, secret_key: str, response_format: dict[str, int | str] | None = None) -> None: super().__init__(base_url="https://nebula-api.thirdweb.com", secret_key=secret_key) + if response_format: + self.response_format = response_format @tool( description="Send a message to Nebula AI and get a response. This can be used for blockchain queries, contract interactions, and access to thirdweb tools." @@ -31,6 +33,8 @@ def chat( data["session_id"] = session_id if context: data["context"] = context + if self.response_format: + data["response_format"] = self.response_format return self._post("/chat", data) diff --git a/python/thirdweb-ai/src/thirdweb_ai/services/service.py b/python/thirdweb-ai/src/thirdweb_ai/services/service.py index 7f19ec8..9046051 100644 --- a/python/thirdweb-ai/src/thirdweb_ai/services/service.py +++ b/python/thirdweb-ai/src/thirdweb_ai/services/service.py @@ -11,9 +11,7 @@ class Service: def __init__(self, base_url: str, secret_key: str, httpx_client: httpx.Client | None = None): self.base_url = base_url self.secret_key = secret_key - self.client = ( - httpx_client or httpx.Client(timeout=120.0, transport=httpx.HTTPTransport(retries=5)) - ) + self.client = httpx_client or httpx.Client(timeout=120.0, transport=httpx.HTTPTransport(retries=5)) def _make_headers(self): kwargs = {"Content-Type": "application/json"} diff --git a/python/thirdweb-ai/src/thirdweb_ai/services/storage.py b/python/thirdweb-ai/src/thirdweb_ai/services/storage.py index 6a1852f..c9b9e16 100644 --- a/python/thirdweb-ai/src/thirdweb_ai/services/storage.py +++ b/python/thirdweb-ai/src/thirdweb_ai/services/storage.py @@ -50,21 +50,21 @@ def _get_file(self, path: str, params: dict[str, Any] | None = None, headers: di _headers.update(headers) response = self.client.get(path, params=params, headers=_headers) response.raise_for_status() - + content_type = response.headers.get("Content-Type", "") - + # Handle JSON responses if "application/json" in content_type: return response.json() - + # Handle binary files (images, pdfs, etc) if content_type.startswith(("image/", "application/pdf", "application/octet-stream")): return {"content": response.content, "content_type": content_type} - + # Handle text content (html, plain text, etc) if content_type.startswith(("text/", "application/xml")): return {"content": response.text, "content_type": content_type} - + # Default fallback - try json first, then return content with type try: return response.json() @@ -101,9 +101,7 @@ def _is_valid_path(self, path: str) -> bool: """Check if the string is a valid file or directory path.""" return Path(path).exists() - def _prepare_directory_files( - self, directory_path: Path, chunk_size: int = 8192 - ) -> list[tuple[str, BytesIO, str]]: + def _prepare_directory_files(self, directory_path: Path, chunk_size: int = 8192) -> list[tuple[str, BytesIO, str]]: """ Prepare files from a directory for upload, preserving directory structure. Returns a list of tuples (relative_path, file_buffer, content_type). diff --git a/python/thirdweb-ai/tests/common/test_utils.py b/python/thirdweb-ai/tests/common/test_utils.py index 09a20e6..96d3728 100644 --- a/python/thirdweb-ai/tests/common/test_utils.py +++ b/python/thirdweb-ai/tests/common/test_utils.py @@ -41,4 +41,3 @@ def test_invalid_digit_string(self): # doesn't trigger this error case since re.search('\d+') always # returns a valid digit string if it matches pass - diff --git a/python/thirdweb-ai/uv.lock b/python/thirdweb-ai/uv.lock index 52f5cbc..00c9fa0 100644 --- a/python/thirdweb-ai/uv.lock +++ b/python/thirdweb-ai/uv.lock @@ -3426,6 +3426,7 @@ dependencies = [ { name = "httpx" }, { name = "jsonref" }, { name = "pydantic" }, + { name = "web3" }, ] [package.optional-dependencies] @@ -3504,6 +3505,7 @@ requires-dist = [ { name = "pydantic-ai", marker = "extra == 'pydantic-ai'", specifier = ">=0.0.39" }, { name = "smolagents", marker = "extra == 'all'", specifier = ">=1.10.0" }, { name = "smolagents", marker = "extra == 'smolagents'", specifier = ">=1.10.0" }, + { name = "web3", specifier = ">=7.9.0" }, ] provides-extras = ["all", "langchain", "goat", "openai", "autogen", "llama-index", "agentkit", "mcp", "smolagents", "pydantic-ai"] diff --git a/python/thirdweb-mcp/pyproject.toml b/python/thirdweb-mcp/pyproject.toml index 66e786a..5c80f7c 100644 --- a/python/thirdweb-mcp/pyproject.toml +++ b/python/thirdweb-mcp/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "thirdweb-mcp" -version = "0.1.13" +version = "0.1.14" description = "thirdweb MCP" authors = [{ name = "thirdweb", email = "support@thirdweb.com" }] requires-python = "~=3.10" @@ -9,7 +9,7 @@ license = "Apache-2.0" dependencies = [ "mcp>=1.3.0,<2", "click>=8.1.8,<9", - "thirdweb-ai[mcp]==0.1.9", + "thirdweb-ai[mcp]==0.1.10", ] [project.scripts]