Skip to content

✨ Do not allow invalid combinations of field parameters for columns and relationships, sa_column excludes sa_column_args, primary_key, nullable, etc. #681

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 28, 2023
151 changes: 141 additions & 10 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TypeVar,
Union,
cast,
overload,
)

from pydantic import BaseConfig, BaseModel
Expand Down Expand Up @@ -87,6 +88,28 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
"Passing sa_column_kwargs is not supported when "
"also passing a sa_column"
)
if primary_key is not Undefined:
raise RuntimeError(
"Passing primary_key is not supported when "
"also passing a sa_column"
)
if nullable is not Undefined:
raise RuntimeError(
"Passing nullable is not supported when " "also passing a sa_column"
)
if foreign_key is not Undefined:
raise RuntimeError(
"Passing foreign_key is not supported when "
"also passing a sa_column"
)
if unique is not Undefined:
raise RuntimeError(
"Passing unique is not supported when " "also passing a sa_column"
)
if index is not Undefined:
raise RuntimeError(
"Passing index is not supported when " "also passing a sa_column"
)
super().__init__(default=default, **kwargs)
self.primary_key = primary_key
self.nullable = nullable
Expand Down Expand Up @@ -126,6 +149,86 @@ def __init__(
self.sa_relationship_kwargs = sa_relationship_kwargs


@overload
def Field(
default: Any = Undefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
include: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
max_digits: Optional[int] = None,
decimal_places: Optional[int] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
primary_key: Union[bool, UndefinedType] = Undefined,
foreign_key: Any = Undefined,
unique: Union[bool, UndefinedType] = Undefined,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
...


@overload
def Field(
default: Any = Undefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
include: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
max_digits: Optional[int] = None,
decimal_places: Optional[int] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
...


def Field(
default: Any = Undefined,
*,
Expand Down Expand Up @@ -156,9 +259,9 @@ def Field(
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
primary_key: bool = False,
foreign_key: Optional[Any] = None,
unique: bool = False,
primary_key: Union[bool, UndefinedType] = Undefined,
foreign_key: Any = Undefined,
unique: Union[bool, UndefinedType] = Undefined,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
Expand Down Expand Up @@ -206,6 +309,27 @@ def Field(
return field_info


@overload
def Relationship(
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
...


@overload
def Relationship(
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
) -> Any:
...


def Relationship(
*,
back_populates: Optional[str] = None,
Expand Down Expand Up @@ -440,21 +564,28 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
if isinstance(sa_column, Column):
return sa_column
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field.field_info, "primary_key", False)
primary_key = getattr(field.field_info, "primary_key", Undefined)
if primary_key is Undefined:
primary_key = False
index = getattr(field.field_info, "index", Undefined)
if index is Undefined:
index = False
nullable = not primary_key and _is_field_noneable(field)
# Override derived nullability if the nullable property is set explicitly
# on the field
if hasattr(field.field_info, "nullable"):
field_nullable = getattr(field.field_info, "nullable") # noqa: B009
if field_nullable != Undefined:
nullable = field_nullable
field_nullable = getattr(field.field_info, "nullable", Undefined) # noqa: B009
if field_nullable != Undefined:
assert not isinstance(field_nullable, UndefinedType)
nullable = field_nullable
args = []
foreign_key = getattr(field.field_info, "foreign_key", None)
unique = getattr(field.field_info, "unique", False)
foreign_key = getattr(field.field_info, "foreign_key", Undefined)
if foreign_key is Undefined:
foreign_key = None
unique = getattr(field.field_info, "unique", Undefined)
if unique is Undefined:
unique = False
if foreign_key:
assert isinstance(foreign_key, str)
args.append(ForeignKey(foreign_key))
kwargs = {
"primary_key": primary_key,
Expand Down
39 changes: 39 additions & 0 deletions tests/test_field_sa_args_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Optional

from sqlalchemy import ForeignKey
from sqlmodel import Field, SQLModel, create_engine


def test_sa_column_args(clear_sqlmodel, caplog) -> None:
class Team(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str

class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
team_id: Optional[int] = Field(
default=None,
sa_column_args=[ForeignKey("team.id")],
)

engine = create_engine("sqlite://", echo=True)
SQLModel.metadata.create_all(engine)
create_table_log = [
message for message in caplog.messages if "CREATE TABLE hero" in message
][0]
assert "FOREIGN KEY(team_id) REFERENCES team (id)" in create_table_log


def test_sa_column_kargs(clear_sqlmodel, caplog) -> None:
class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
sa_column_kwargs={"primary_key": True},
)

engine = create_engine("sqlite://", echo=True)
SQLModel.metadata.create_all(engine)
create_table_log = [
message for message in caplog.messages if "CREATE TABLE item" in message
][0]
assert "PRIMARY KEY (id)" in create_table_log
99 changes: 99 additions & 0 deletions tests/test_field_sa_column.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Optional

import pytest
from sqlalchemy import Column, Integer, String
from sqlmodel import Field, SQLModel


def test_sa_column_takes_precedence() -> None:
class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
sa_column=Column(String, primary_key=True, nullable=False),
)

# It would have been nullable with no sa_column
assert Item.id.nullable is False # type: ignore
assert isinstance(Item.id.type, String) # type: ignore


def test_sa_column_no_sa_args() -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
sa_column_args=[Integer],
sa_column=Column(Integer, primary_key=True),
)


def test_sa_column_no_sa_kargs() -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
sa_column_kwargs={"primary_key": True},
sa_column=Column(Integer, primary_key=True),
)


def test_sa_column_no_primary_key() -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
primary_key=True,
sa_column=Column(Integer, primary_key=True),
)


def test_sa_column_no_nullable() -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
nullable=True,
sa_column=Column(Integer, primary_key=True),
)


def test_sa_column_no_foreign_key() -> None:
with pytest.raises(RuntimeError):

class Team(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str

class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
team_id: Optional[int] = Field(
default=None,
foreign_key="team.id",
sa_column=Column(Integer, primary_key=True),
)


def test_sa_column_no_unique() -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
unique=True,
sa_column=Column(Integer, primary_key=True),
)


def test_sa_column_no_index() -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
index=True,
sa_column=Column(Integer, primary_key=True),
)
53 changes: 53 additions & 0 deletions tests/test_field_sa_relationship.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import List, Optional

import pytest
from sqlalchemy.orm import relationship
from sqlmodel import Field, Relationship, SQLModel


def test_sa_relationship_no_args() -> None:
with pytest.raises(RuntimeError):

class Team(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(index=True)
headquarters: str

heroes: List["Hero"] = Relationship(
back_populates="team",
sa_relationship_args=["Hero"],
sa_relationship=relationship("Hero", back_populates="team"),
)

class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(index=True)
secret_name: str
age: Optional[int] = Field(default=None, index=True)

team_id: Optional[int] = Field(default=None, foreign_key="team.id")
team: Optional[Team] = Relationship(back_populates="heroes")


def test_sa_relationship_no_kwargs() -> None:
with pytest.raises(RuntimeError):

class Team(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(index=True)
headquarters: str

heroes: List["Hero"] = Relationship(
back_populates="team",
sa_relationship_kwargs={"lazy": "selectin"},
sa_relationship=relationship("Hero", back_populates="team"),
)

class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(index=True)
secret_name: str
age: Optional[int] = Field(default=None, index=True)

team_id: Optional[int] = Field(default=None, foreign_key="team.id")
team: Optional[Team] = Relationship(back_populates="heroes")