diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3015aa9fbd..f48e388e13 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -22,6 +22,7 @@ TypeVar, Union, cast, + overload, ) from pydantic import BaseConfig, BaseModel @@ -87,6 +88,28 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: "Passing sa_column_kwargs is not supported when " "also passing a sa_column" ) + if primary_key is not Undefined: + raise RuntimeError( + "Passing primary_key is not supported when " + "also passing a sa_column" + ) + if nullable is not Undefined: + raise RuntimeError( + "Passing nullable is not supported when " "also passing a sa_column" + ) + if foreign_key is not Undefined: + raise RuntimeError( + "Passing foreign_key is not supported when " + "also passing a sa_column" + ) + if unique is not Undefined: + raise RuntimeError( + "Passing unique is not supported when " "also passing a sa_column" + ) + if index is not Undefined: + raise RuntimeError( + "Passing index is not supported when " "also passing a sa_column" + ) super().__init__(default=default, **kwargs) self.primary_key = primary_key self.nullable = nullable @@ -126,6 +149,86 @@ def __init__( self.sa_relationship_kwargs = sa_relationship_kwargs +@overload +def Field( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + unique_items: Optional[bool] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, + primary_key: Union[bool, UndefinedType] = Undefined, + foreign_key: Any = Undefined, + unique: Union[bool, UndefinedType] = Undefined, + nullable: Union[bool, UndefinedType] = Undefined, + index: Union[bool, UndefinedType] = Undefined, + sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, + sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + schema_extra: Optional[Dict[str, Any]] = None, +) -> Any: + ... + + +@overload +def Field( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + unique_items: Optional[bool] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, + sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore + schema_extra: Optional[Dict[str, Any]] = None, +) -> Any: + ... + + def Field( default: Any = Undefined, *, @@ -156,9 +259,9 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - primary_key: bool = False, - foreign_key: Optional[Any] = None, - unique: bool = False, + primary_key: Union[bool, UndefinedType] = Undefined, + foreign_key: Any = Undefined, + unique: Union[bool, UndefinedType] = Undefined, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore @@ -206,6 +309,27 @@ def Field( return field_info +@overload +def Relationship( + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, + sa_relationship_args: Optional[Sequence[Any]] = None, + sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, +) -> Any: + ... + + +@overload +def Relationship( + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, + sa_relationship: Optional[RelationshipProperty] = None, # type: ignore +) -> Any: + ... + + def Relationship( *, back_populates: Optional[str] = None, @@ -440,21 +564,28 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore if isinstance(sa_column, Column): return sa_column sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field.field_info, "primary_key", False) + primary_key = getattr(field.field_info, "primary_key", Undefined) + if primary_key is Undefined: + primary_key = False index = getattr(field.field_info, "index", Undefined) if index is Undefined: index = False nullable = not primary_key and _is_field_noneable(field) # Override derived nullability if the nullable property is set explicitly # on the field - if hasattr(field.field_info, "nullable"): - field_nullable = getattr(field.field_info, "nullable") # noqa: B009 - if field_nullable != Undefined: - nullable = field_nullable + field_nullable = getattr(field.field_info, "nullable", Undefined) # noqa: B009 + if field_nullable != Undefined: + assert not isinstance(field_nullable, UndefinedType) + nullable = field_nullable args = [] - foreign_key = getattr(field.field_info, "foreign_key", None) - unique = getattr(field.field_info, "unique", False) + foreign_key = getattr(field.field_info, "foreign_key", Undefined) + if foreign_key is Undefined: + foreign_key = None + unique = getattr(field.field_info, "unique", Undefined) + if unique is Undefined: + unique = False if foreign_key: + assert isinstance(foreign_key, str) args.append(ForeignKey(foreign_key)) kwargs = { "primary_key": primary_key, diff --git a/tests/test_field_sa_args_kwargs.py b/tests/test_field_sa_args_kwargs.py new file mode 100644 index 0000000000..94a1a13483 --- /dev/null +++ b/tests/test_field_sa_args_kwargs.py @@ -0,0 +1,39 @@ +from typing import Optional + +from sqlalchemy import ForeignKey +from sqlmodel import Field, SQLModel, create_engine + + +def test_sa_column_args(clear_sqlmodel, caplog) -> None: + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + team_id: Optional[int] = Field( + default=None, + sa_column_args=[ForeignKey("team.id")], + ) + + engine = create_engine("sqlite://", echo=True) + SQLModel.metadata.create_all(engine) + create_table_log = [ + message for message in caplog.messages if "CREATE TABLE hero" in message + ][0] + assert "FOREIGN KEY(team_id) REFERENCES team (id)" in create_table_log + + +def test_sa_column_kargs(clear_sqlmodel, caplog) -> None: + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column_kwargs={"primary_key": True}, + ) + + engine = create_engine("sqlite://", echo=True) + SQLModel.metadata.create_all(engine) + create_table_log = [ + message for message in caplog.messages if "CREATE TABLE item" in message + ][0] + assert "PRIMARY KEY (id)" in create_table_log diff --git a/tests/test_field_sa_column.py b/tests/test_field_sa_column.py new file mode 100644 index 0000000000..51cfdfa797 --- /dev/null +++ b/tests/test_field_sa_column.py @@ -0,0 +1,99 @@ +from typing import Optional + +import pytest +from sqlalchemy import Column, Integer, String +from sqlmodel import Field, SQLModel + + +def test_sa_column_takes_precedence() -> None: + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column=Column(String, primary_key=True, nullable=False), + ) + + # It would have been nullable with no sa_column + assert Item.id.nullable is False # type: ignore + assert isinstance(Item.id.type, String) # type: ignore + + +def test_sa_column_no_sa_args() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column_args=[Integer], + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_sa_kargs() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column_kwargs={"primary_key": True}, + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_primary_key() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + primary_key=True, + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_nullable() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + nullable=True, + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_foreign_key() -> None: + with pytest.raises(RuntimeError): + + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + team_id: Optional[int] = Field( + default=None, + foreign_key="team.id", + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_unique() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + unique=True, + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_index() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + index=True, + sa_column=Column(Integer, primary_key=True), + ) diff --git a/tests/test_field_sa_relationship.py b/tests/test_field_sa_relationship.py new file mode 100644 index 0000000000..7606fd86d8 --- /dev/null +++ b/tests/test_field_sa_relationship.py @@ -0,0 +1,53 @@ +from typing import List, Optional + +import pytest +from sqlalchemy.orm import relationship +from sqlmodel import Field, Relationship, SQLModel + + +def test_sa_relationship_no_args() -> None: + with pytest.raises(RuntimeError): + + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + headquarters: str + + heroes: List["Hero"] = Relationship( + back_populates="team", + sa_relationship_args=["Hero"], + sa_relationship=relationship("Hero", back_populates="team"), + ) + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + secret_name: str + age: Optional[int] = Field(default=None, index=True) + + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional[Team] = Relationship(back_populates="heroes") + + +def test_sa_relationship_no_kwargs() -> None: + with pytest.raises(RuntimeError): + + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + headquarters: str + + heroes: List["Hero"] = Relationship( + back_populates="team", + sa_relationship_kwargs={"lazy": "selectin"}, + sa_relationship=relationship("Hero", back_populates="team"), + ) + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + secret_name: str + age: Optional[int] = Field(default=None, index=True) + + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional[Team] = Relationship(back_populates="heroes")