Skip to content

add various improvements for tools #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/thirdweb-ai/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "thirdweb-ai"
version = "0.1.9"
version = "0.1.10"
description = "thirdweb AI"
authors = [{ name = "thirdweb", email = "[email protected]" }]
requires-python = ">=3.10,<4.0"
Expand All @@ -20,6 +20,7 @@ dependencies = [
"jsonref>=1.1.0,<2",
"httpx>=0.28.1,<0.29",
"aiohttp>=3.11.14",
"web3>=7.9.0",
]

[project.optional-dependencies]
Expand Down
52 changes: 52 additions & 0 deletions python/thirdweb-ai/src/thirdweb_ai/common/address.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import re

from web3 import Web3


def validate_address(address: str) -> str:
if not address.startswith("0x") or len(address) != 42:
raise ValueError(f"Invalid blockchain address format: {address}")

if not Web3.is_checksum_address(address):
try:
return Web3.to_checksum_address(address)
except ValueError as e:
raise ValueError(f"Invalid blockchain address: {address}") from e

return address


def validate_transaction_hash(tx_hash: str) -> str:
pattern = re.compile(r"^0x[a-fA-F0-9]{64}$")
if bool(re.fullmatch(pattern, tx_hash)):
return tx_hash
raise ValueError(f"Invalid transaction hash: {tx_hash}")


def validate_block_identifier(block_id: str) -> str:
if block_id.startswith("0x"):
pattern = re.compile(r"^0x[a-fA-F0-9]{64}$")
if bool(re.fullmatch(pattern, block_id)):
return block_id
elif block_id.isdigit():
return block_id

raise ValueError(f"Invalid block identifier: {block_id}")


def validate_signature(signature: str) -> str:
# Function selector (4 bytes)
if signature.startswith("0x") and len(signature) == 10:
pattern = re.compile(r"^0x[a-fA-F0-9]{8}$")
if bool(re.fullmatch(pattern, signature)):
return signature
# Event topic (32 bytes)
elif signature.startswith("0x") and len(signature) == 66:
pattern = re.compile(r"^0x[a-fA-F0-9]{64}$")
if bool(re.fullmatch(pattern, signature)):
return signature
# Plain text signature (e.g. "transfer(address,uint256)")
elif "(" in signature and ")" in signature:
return signature

raise ValueError(f"Invalid function or event signature: {signature}")
53 changes: 39 additions & 14 deletions python/thirdweb-ai/src/thirdweb_ai/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,33 @@
import re
from typing import Any

TRANSACTION_KEYS_TO_KEEP = [
"hash",
"block_number",
"block_timestamp",
"from_address",
"to_address",
"value",
"decodedData",
]
EVENT_KEYS_TO_KEEP = [
"block_number",
"block_timestamp",
"address",
"transaction_hash",
"transaction_index",
"log_index",
"topics",
"data",
"decodedData",
]


def extract_digits(value: int | str) -> int:
"""Extract the integer value from a string or return the integer directly."""
if isinstance(value, int):
return value

value_str = str(value).strip("\"'")
digit_match = re.search(r"\d+", value_str)

Expand All @@ -16,21 +41,8 @@ def extract_digits(value: int | str) -> int:
return int(extracted_digits)


def normalize_chain_id(
in_value: int | str | list[int | str] | None,
) -> int | list[int] | None:
"""Normalize str values integers."""

if in_value is None:
return None

if isinstance(in_value, list):
return [extract_digits(c) for c in in_value]

return extract_digits(in_value)


def is_encoded(encoded_data: str) -> bool:
"""Check if a string is a valid hexadecimal value."""
encoded_data = encoded_data.removeprefix("0x")

try:
Expand All @@ -41,10 +53,23 @@ def is_encoded(encoded_data: str) -> bool:


def clean_resolve(out: dict[str, Any]):
"""Clean the response from the resolve function."""
if "transactions" in out["data"]:
for transaction in out["data"]["transactions"]:
if "data" in transaction and is_encoded(transaction["data"]):
transaction.pop("data")
if "logs_bloom" in transaction:
transaction.pop("logs_bloom")
return out


def filter_response_keys(items: list[dict[str, Any]], keys_to_keep: list[str] | None) -> list[dict[str, Any]]:
"""Filter the response items to only include the specified keys"""
if not keys_to_keep:
return items

for item in items:
keys_to_remove = [key for key in item if key not in keys_to_keep]
for key in keys_to_remove:
item.pop(key, None)
return items
50 changes: 33 additions & 17 deletions python/thirdweb-ai/src/thirdweb_ai/services/engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Annotated, Any
from typing import Annotated, Any, Literal

from thirdweb_ai.common.utils import extract_digits, normalize_chain_id
from thirdweb_ai.common.utils import extract_digits
from thirdweb_ai.services.service import Service
from thirdweb_ai.tools.tool import tool

Expand All @@ -10,15 +10,15 @@ def __init__(
self,
engine_url: str,
engine_auth_jwt: str,
chain_id: int | str | None = None,
chain_id: int | None = None,
backend_wallet_address: str | None = None,
secret_key: str = "",
):
super().__init__(base_url=engine_url, secret_key=secret_key)
self.engine_url = engine_url
self.engine_auth_jwt = engine_auth_jwt
self.backend_wallet_address = backend_wallet_address
self.chain_id = normalize_chain_id(chain_id)
self.chain_id = chain_id

def _make_headers(self):
headers = super()._make_headers()
Expand All @@ -34,7 +34,7 @@ def _make_headers(self):
def create_backend_wallet(
self,
wallet_type: Annotated[
str,
Literal["local", "smart:local"],
"The type of backend wallet to create. Currently supported options are 'local' (stored locally in Engine's database) or 'smart:local' (for smart account wallets with advanced features). Choose 'local' for standard EOA wallets, and 'smart:local' for smart contract wallets with batching capabilities.",
],
label: Annotated[
Expand Down Expand Up @@ -76,7 +76,7 @@ def get_all_backend_wallet(
def get_wallet_balance(
self,
chain_id: Annotated[
str | int,
int | None,
"The numeric blockchain network ID to query (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.",
],
backend_wallet_address: Annotated[
Expand All @@ -85,9 +85,13 @@ def get_wallet_balance(
] = None,
) -> dict[str, Any]:
"""Get wallet balance for native or ERC20 tokens."""
normalized_chain = normalize_chain_id(chain_id) or self.chain_id
if self.chain_id is not None and chain_id is None:
chain_id = self.chain_id
elif chain_id is None:
raise ValueError("chain_id is required")

backend_wallet_address = backend_wallet_address or self.backend_wallet_address
return self._get(f"backend-wallet/{normalized_chain}/{backend_wallet_address}/get-balance")
return self._get(f"backend-wallet/{chain_id}/{backend_wallet_address}/get-balance")

@tool(
description="Send an on-chain transaction. This powerful function can transfer native currency (ETH, MATIC), ERC20 tokens, or execute any arbitrary contract interaction. The transaction is signed and broadcast to the blockchain automatically."
Expand All @@ -107,7 +111,7 @@ def send_transaction(
"The hexadecimal transaction data payload for contract interactions (e.g., '0x23b872dd...'). For simple native currency transfers, leave this empty. For ERC20 transfers or contract calls, this contains the ABI-encoded function call.",
],
chain_id: Annotated[
str | int,
int | None,
"The numeric blockchain network ID to send the transaction on (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.",
],
backend_wallet_address: Annotated[
Expand All @@ -126,10 +130,14 @@ def send_transaction(
"data": data or "0x",
}

normalized_chain = normalize_chain_id(chain_id) or self.chain_id
if self.chain_id is not None and chain_id is None:
chain_id = self.chain_id
elif chain_id is None:
raise ValueError("chain_id is required")

backend_wallet_address = backend_wallet_address or self.backend_wallet_address
return self._post(
f"backend-wallet/{normalized_chain}/send-transaction",
f"backend-wallet/{chain_id}/send-transaction",
payload,
headers={"X-Backend-Wallet-Address": backend_wallet_address},
)
Expand Down Expand Up @@ -165,7 +173,7 @@ def read_contract(
"An ordered list of arguments to pass to the function (e.g., [address, tokenId]). Must match the types and order expected by the function. For functions with no parameters, use an empty list or None.",
],
chain_id: Annotated[
str | int,
int | None,
"The numeric blockchain network ID where the contract is deployed (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.",
],
) -> dict[str, Any]:
Expand All @@ -174,8 +182,12 @@ def read_contract(
"functionName": function_name,
"args": function_args or [],
}
normalized_chain = normalize_chain_id(chain_id) or self.chain_id
return self._get(f"contract/{normalized_chain}/{contract_address}/read", payload)
if self.chain_id is not None and chain_id is None:
chain_id = self.chain_id
elif chain_id is None:
raise ValueError("chain_id is required")

return self._get(f"contract/{chain_id}/{contract_address}/read", payload)

@tool(
description="Execute a state-changing function on a smart contract by sending a transaction. This allows you to modify on-chain data, such as transferring tokens, minting NFTs, or updating contract configuration. The transaction is automatically signed by your backend wallet and submitted to the blockchain."
Expand All @@ -199,7 +211,7 @@ def write_contract(
"The amount of native currency (ETH, MATIC, etc.) to send with the transaction, in wei (e.g., '1000000000000000000' for 1 ETH). Required for payable functions, use '0' for non-payable functions. Default to '0'.",
],
chain_id: Annotated[
str | int,
int | None,
"The numeric blockchain network ID where the contract is deployed (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.",
],
) -> dict[str, Any]:
Expand All @@ -212,9 +224,13 @@ def write_contract(
if value and value != "0":
payload["txOverrides"] = {"value": value}

normalized_chain = normalize_chain_id(chain_id) or self.chain_id
if self.chain_id is not None and chain_id is None:
chain_id = self.chain_id
elif chain_id is None:
raise ValueError("chain_id is required")

return self._post(
f"contract/{normalized_chain}/{contract_address}/write",
f"contract/{chain_id}/{contract_address}/write",
payload,
headers={"X-Backend-Wallet-Address": self.backend_wallet_address},
)
Loading