diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 201abc7c22..3b7cfaf95a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,11 +20,8 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: - - "3.7" - - "3.8" - - "3.9" - - "3.10" + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + pydantic-version: ["pydantic-v1", "pydantic-v2"] fail-fast: false steps: @@ -55,8 +52,12 @@ jobs: - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' run: python -m poetry install - - name: Lint - run: python -m poetry run bash scripts/lint.sh + - name: Install Pydantic v1 + if: matrix.pydantic-version == 'pydantic-v1' + run: pip install "pydantic>=1.10.0,<2.0.0" + - name: Install Pydantic v2 + if: matrix.pydantic-version == 'pydantic-v2' + run: pip install "pydantic>=2.0.2,<3.0.0" - run: mkdir coverage - name: Test run: python -m poetry run bash scripts/test.sh @@ -68,6 +69,8 @@ jobs: with: name: coverage path: coverage + - name: Lint + run: python -m poetry run bash scripts/lint.sh coverage-combine: needs: - test diff --git a/docs/tutorial/fastapi/multiple-models.md b/docs/tutorial/fastapi/multiple-models.md index 6845b9862d..4ea24c6752 100644 --- a/docs/tutorial/fastapi/multiple-models.md +++ b/docs/tutorial/fastapi/multiple-models.md @@ -174,13 +174,13 @@ Now we use the type annotation `HeroCreate` for the request JSON data in the `he # Code below omitted 👇 ``` -Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.from_orm()`. +Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.model_validate()`. -The method `.from_orm()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`. +The method `.model_validate()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`. The alternative is `Hero.parse_obj()` that reads data from a dictionary. -But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.from_orm()` to read those attributes. +But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.model_validate()` to read those attributes. With this, we create a new `Hero` instance (the one for the database) and put it in the variable `db_hero` from the data in the `hero` variable that is the `HeroCreate` instance we received from the request. diff --git a/docs/tutorial/index.md b/docs/tutorial/index.md index 74107776c2..54e1147d68 100644 --- a/docs/tutorial/index.md +++ b/docs/tutorial/index.md @@ -64,6 +64,8 @@ $ cd sqlmodel-tutorial Make sure you have an officially supported version of Python. +Currently it is **Python 3.7** and above (Python 3.6 was already deprecated). + You can check which version you have with:
@@ -79,9 +81,11 @@ There's a chance that you have multiple Python versions installed. You might want to try with the specific versions, for example with: +* `python3.11` * `python3.10` * `python3.9` * `python3.8` +* `python3.7` The code would look like this: diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py index 3f0602e4b4..f305f75194 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py @@ -2,6 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -54,7 +55,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/delete/tutorial001.py b/docs_src/tutorial/fastapi/delete/tutorial001.py index 3069fc5e87..f186c42b2b 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -50,7 +51,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py index 2b8739ca70..6701355f17 100644 --- a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py +++ b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -44,7 +45,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py index df20123333..0ceed94ca1 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class Hero(SQLModel, table=True): @@ -46,7 +47,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py index 392c2c5829..d92745a339 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py @@ -2,6 +2,7 @@ from fastapi import FastAPI from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -44,7 +45,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001.py b/docs_src/tutorial/fastapi/read_one/tutorial001.py index 4d66e471a5..aa805d6c8f 100644 --- a/docs_src/tutorial/fastapi/read_one/tutorial001.py +++ b/docs_src/tutorial/fastapi/read_one/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -44,7 +45,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001.py b/docs_src/tutorial/fastapi/relationships/tutorial001.py index 8477e4a2a0..dfcedaf881 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class TeamBase(SQLModel): @@ -92,7 +93,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -146,7 +150,10 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.from_orm(team) + if IS_PYDANTIC_V2: + db_team = Team.model_validate(team) + else: + db_team = Team.from_orm(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py index 3f0602e4b4..f305f75194 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -54,7 +55,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/teams/tutorial001.py b/docs_src/tutorial/fastapi/teams/tutorial001.py index 1da0dad8a2..46ea0f933c 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class TeamBase(SQLModel): @@ -83,7 +84,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -137,7 +141,10 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.from_orm(team) + if IS_PYDANTIC_V2: + db_team = Team.model_validate(team) + else: + db_team = Team.from_orm(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/update/tutorial001.py b/docs_src/tutorial/fastapi/update/tutorial001.py index bb98efd581..93dfa7496a 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001.py +++ b/docs_src/tutorial/fastapi/update/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -50,7 +51,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/many_to_many/tutorial003.py b/docs_src/tutorial/many_to_many/tutorial003.py index 1e03c4af89..cec6e56560 100644 --- a/docs_src/tutorial/many_to_many/tutorial003.py +++ b/docs_src/tutorial/many_to_many/tutorial003.py @@ -3,25 +3,12 @@ from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select -class HeroTeamLink(SQLModel, table=True): - team_id: Optional[int] = Field( - default=None, foreign_key="team.id", primary_key=True - ) - hero_id: Optional[int] = Field( - default=None, foreign_key="hero.id", primary_key=True - ) - is_training: bool = False - - team: "Team" = Relationship(back_populates="hero_links") - hero: "Hero" = Relationship(back_populates="team_links") - - class Team(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str = Field(index=True) headquarters: str - hero_links: List[HeroTeamLink] = Relationship(back_populates="team") + hero_links: List["HeroTeamLink"] = Relationship(back_populates="team") class Hero(SQLModel, table=True): @@ -30,7 +17,20 @@ class Hero(SQLModel, table=True): secret_name: str age: Optional[int] = Field(default=None, index=True) - team_links: List[HeroTeamLink] = Relationship(back_populates="hero") + team_links: List["HeroTeamLink"] = Relationship(back_populates="hero") + + +class HeroTeamLink(SQLModel, table=True): + team_id: Optional[int] = Field( + default=None, foreign_key="team.id", primary_key=True + ) + hero_id: Optional[int] = Field( + default=None, foreign_key="hero.id", primary_key=True + ) + is_training: bool = False + + team: "Team" = Relationship(back_populates="hero_links") + hero: "Hero" = Relationship(back_populates="team_links") sqlite_file_name = "database.db" diff --git a/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py b/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py index 98e197002e..8d91a0bc25 100644 --- a/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py +++ b/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py @@ -3,6 +3,21 @@ from sqlmodel import Field, Relationship, SQLModel, create_engine +class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + secret_name: str + age: Optional[int] = Field(default=None, index=True) + + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional["Team"] = Relationship(back_populates="heroes") + + weapon_id: Optional[int] = Field(default=None, foreign_key="weapon.id") + weapon: Optional["Weapon"] = Relationship(back_populates="hero") + + powers: List["Power"] = Relationship(back_populates="hero") + + class Weapon(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str = Field(index=True) @@ -26,21 +41,6 @@ class Team(SQLModel, table=True): heroes: List["Hero"] = Relationship(back_populates="team") -class Hero(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - name: str = Field(index=True) - secret_name: str - age: Optional[int] = Field(default=None, index=True) - - team_id: Optional[int] = Field(default=None, foreign_key="team.id") - team: Optional[Team] = Relationship(back_populates="heroes") - - weapon_id: Optional[int] = Field(default=None, foreign_key="weapon.id") - weapon: Optional[Weapon] = Relationship(back_populates="hero") - - powers: List[Power] = Relationship(back_populates="hero") - - sqlite_file_name = "database.db" sqlite_url = f"sqlite:///{sqlite_file_name}" diff --git a/pyproject.toml b/pyproject.toml index 23fa79bf31..f104631655 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Database", "Topic :: Database :: Database Engines/Servers", "Topic :: Internet", @@ -45,7 +46,7 @@ pillow = "^9.3.0" cairosvg = "^2.5.2" mdx-include = "^1.4.1" coverage = {extras = ["toml"], version = ">=6.2,<8.0"} -fastapi = "^0.68.1" +fastapi = "^0.100.0" requests = "^2.26.0" ruff = "^0.1.2" diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index 495ac9c8a8..7e20e1ba41 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -21,7 +21,6 @@ from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint from sqlalchemy.schema import Sequence as Sequence from sqlalchemy.schema import Table as Table -from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData from sqlalchemy.schema import UniqueConstraint as UniqueConstraint from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT from sqlalchemy.sql import ( @@ -31,6 +30,7 @@ from sqlalchemy.sql import ( LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, ) +from sqlalchemy.sql import Subquery as Subquery from sqlalchemy.sql import alias as alias from sqlalchemy.sql import all_ as all_ from sqlalchemy.sql import and_ as and_ @@ -71,7 +71,6 @@ from sqlalchemy.sql import outerjoin as outerjoin from sqlalchemy.sql import outparam as outparam from sqlalchemy.sql import over as over -from sqlalchemy.sql import subquery as subquery from sqlalchemy.sql import table as table from sqlalchemy.sql import tablesample as tablesample from sqlalchemy.sql import text as text diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py new file mode 100644 index 0000000000..dbd22053a8 --- /dev/null +++ b/sqlmodel/compat.py @@ -0,0 +1,416 @@ +import ipaddress +import uuid +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import Enum +from pathlib import Path +from types import NoneType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + ForwardRef, + Optional, + Sequence, + Type, + TypeVar, + Union, + cast, + get_args, + get_origin, +) + +from pydantic import VERSION as PYDANTIC_VERSION +from sqlalchemy import ( + Boolean, + Column, + Date, + DateTime, + Float, + ForeignKey, + Integer, + Interval, + Numeric, +) +from sqlalchemy import Enum as sa_Enum +from sqlalchemy.sql.sqltypes import LargeBinary, Time + +from .sql.sqltypes import GUID, AutoString + +IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2 + + +if IS_PYDANTIC_V2: + from pydantic import ConfigDict as PydanticModelConfig + from pydantic._internal._fields import PydanticMetadata + from pydantic._internal._model_construction import ModelMetaclass + from pydantic_core import PydanticUndefined as PydanticUndefined # noqa + from pydantic_core import PydanticUndefinedType as PydanticUndefinedType +else: + from pydantic import BaseConfig as PydanticModelConfig + from pydantic.fields import SHAPE_SINGLETON, ModelField + from pydantic.fields import Undefined as PydanticUndefined # noqa + from pydantic.fields import UndefinedType as PydanticUndefinedType + from pydantic.main import ModelMetaclass as ModelMetaclass + from pydantic.typing import resolve_annotations + +if TYPE_CHECKING: + from .main import FieldInfo, RelationshipInfo, SQLModel, SQLModelMetaclass + + +NoArgAnyCallable = Callable[[], Any] +T = TypeVar("T") +InstanceOrType = Union[T, Type[T]] + +if IS_PYDANTIC_V2: + + class SQLModelConfig(PydanticModelConfig, total=False): + table: Optional[bool] + registry: Optional[Any] + +else: + + class SQLModelConfig(PydanticModelConfig): + table: Optional[bool] = None + registry: Optional[Any] = None + + +# Inspired from https://github.com/roman-right/beanie/blob/main/beanie/odm/utils/pydantic.py +def get_model_config(model: type) -> Optional[SQLModelConfig]: + if IS_PYDANTIC_V2: + return getattr(model, "model_config", None) + else: + return getattr(model, "__config__", None) + + +def get_config_value( + model: InstanceOrType["SQLModel"], parameter: str, default: Any = None +) -> Any: + if IS_PYDANTIC_V2: + return model.model_config.get(parameter, default) + else: + return getattr(model.__config__, parameter, default) + + +def set_config_value( + model: InstanceOrType["SQLModel"], + parameter: str, + value: Any, + v1_parameter: Optional[str] = None, +) -> None: + if IS_PYDANTIC_V2: + model.model_config[parameter] = value # type: ignore + else: + setattr(model.__config__, v1_parameter or parameter, value) # type: ignore + + +def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: + if IS_PYDANTIC_V2: + return model.model_fields # type: ignore + else: + return model.__fields__ # type: ignore + + +def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]: + if IS_PYDANTIC_V2: + return model.__pydantic_fields_set__ # type: ignore + else: + return model.__fields_set__ # type: ignore + + +def set_fields_set( + new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"] +) -> None: + if IS_PYDANTIC_V2: + object.__setattr__(new_object, "__pydantic_fields_set__", fields) + else: + object.__setattr__(new_object, "__fields_set__", fields) + + +def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None: + if IS_PYDANTIC_V2: + cls.model_config["read_from_attributes"] = True + else: + cls.__config__.read_with_orm_mode = True # type: ignore + + +def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: + if IS_PYDANTIC_V2: + return class_dict.get("__annotations__", {}) + else: + return resolve_annotations( + class_dict.get("__annotations__", {}), class_dict.get("__module__", None) + ) + + +def class_dict_is_table( + class_dict: dict[str, Any], class_kwargs: dict[str, Any] +) -> bool: + config: SQLModelConfig = {} + if IS_PYDANTIC_V2: + config = class_dict.get("model_config", {}) + else: + config = class_dict.get("__config__", {}) + config_table = config.get("table", PydanticUndefined) + if config_table is not PydanticUndefined: + return config_table # type: ignore + kw_table = class_kwargs.get("table", PydanticUndefined) + if kw_table is not PydanticUndefined: + return kw_table # type: ignore + return False + + +def cls_is_table(cls: Type) -> bool: + if IS_PYDANTIC_V2: + config = getattr(cls, "model_config", None) + if not config: + return False + return config.get("table", False) + else: + config = getattr(cls, "__config__", None) + if not config: + return False + return getattr(config, "table", False) + + +def get_relationship_to( + name: str, + rel_info: "RelationshipInfo", + annotation: Any, +) -> Any: + if IS_PYDANTIC_V2: + relationship_to = get_origin(annotation) + # Direct relationships (e.g. 'Team' or Team) have None as an origin + if relationship_to is None: + relationship_to = annotation + # If Union (e.g. Optional), get the real field + elif relationship_to is Union: + relationship_to = get_args(annotation)[0] + # If a list, then also get the real field + elif relationship_to is list: + relationship_to = get_args(annotation)[0] + if isinstance(relationship_to, ForwardRef): + relationship_to = relationship_to.__forward_arg__ + return relationship_to + else: + temp_field = ModelField.infer( + name=name, + value=rel_info, + annotation=annotation, + class_validators=None, + config=SQLModelConfig, + ) + relationship_to = temp_field.type_ + if isinstance(temp_field.type_, ForwardRef): + relationship_to = temp_field.type_.__forward_arg__ + return relationship_to + + +def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) -> None: + """ + Pydantic v2 without required fields with no optionals cannot do empty initialisations. + This means we cannot do Table() and set fields later. + We go around this by adding a default to everything, being None + + Args: + annotations: Dict[str, Any]: The annotations to provide to pydantic + class_dict: Dict[str, Any]: The class dict for the defaults + """ + if IS_PYDANTIC_V2: + from .main import FieldInfo + + # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything + for key in annotations.keys(): + value = class_dict.get(key, PydanticUndefined) + if value is PydanticUndefined: + class_dict[key] = None + elif isinstance(value, FieldInfo): + if ( + value.default in (PydanticUndefined, Ellipsis) + ) and value.default_factory is None: + # So we can check for nullable + value.default = None + + +def _is_field_noneable(field: "FieldInfo") -> bool: + if IS_PYDANTIC_V2: + if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: + return field.nullable # type: ignore + if not field.is_required(): + if field.default is PydanticUndefined: + return False + if field.annotation is None or field.annotation is NoneType: + return True + if get_origin(field.annotation) is Union: + for base in get_args(field.annotation): + if base is NoneType: + return True + return False + return False + else: + if not field.required: + # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) + return field.allow_none and ( + field.shape != SHAPE_SINGLETON or not field.sub_fields + ) + return field.allow_none + + +def get_sqlalchemy_type(field: Any) -> Any: + if IS_PYDANTIC_V2: + field_info = field + else: + field_info = field.field_info + sa_type = getattr(field_info, "sa_type", PydanticUndefined) # noqa: B009 + if sa_type is not PydanticUndefined: + return sa_type + + type_ = get_type_from_field(field) + metadata = get_field_metadata(field) + + # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI + if issubclass(type_, Enum): + return sa_Enum(type_) + if issubclass(type_, str): + max_length = getattr(metadata, "max_length", None) + if max_length: + return AutoString(length=max_length) + return AutoString + if issubclass(type_, float): + return Float + if issubclass(type_, bool): + return Boolean + if issubclass(type_, int): + return Integer + if issubclass(type_, datetime): + return DateTime + if issubclass(type_, date): + return Date + if issubclass(type_, timedelta): + return Interval + if issubclass(type_, time): + return Time + if issubclass(type_, bytes): + return LargeBinary + if issubclass(type_, Decimal): + return Numeric( + precision=getattr(metadata, "max_digits", None), + scale=getattr(metadata, "decimal_places", None), + ) + if issubclass(type_, ipaddress.IPv4Address): + return AutoString + if issubclass(type_, ipaddress.IPv4Network): + return AutoString + if issubclass(type_, ipaddress.IPv6Address): + return AutoString + if issubclass(type_, ipaddress.IPv6Network): + return AutoString + if issubclass(type_, Path): + return AutoString + if issubclass(type_, uuid.UUID): + return GUID + raise ValueError(f"{type_} has no matching SQLAlchemy type") + + +def get_type_from_field(field: Any) -> type: + if IS_PYDANTIC_V2: + type_: type | None = field.annotation + # Resolve Optional fields + if type_ is None: + raise ValueError("Missing field type") + origin = get_origin(type_) + if origin is None: + return type_ + if origin is Union: + bases = get_args(type_) + if len(bases) > 2: + raise ValueError( + "Cannot have a (non-optional) union as a SQL alchemy field" + ) + # Non optional unions are not allowed + if bases[0] is not NoneType and bases[1] is not NoneType: + raise ValueError( + "Cannot have a (non-optional) union as a SQL alchemy field" + ) + # Optional unions are allowed + return bases[0] if bases[0] is not NoneType else bases[1] + return origin + else: + if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: + return field.type_ + raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") + + +class FakeMetadata: + max_length: Optional[int] = None + max_digits: Optional[int] = None + decimal_places: Optional[int] = None + + +def get_field_metadata(field: Any) -> Any: + if IS_PYDANTIC_V2: + for meta in field.metadata: + if isinstance(meta, PydanticMetadata): + return meta + return FakeMetadata() + else: + metadata = FakeMetadata() + metadata.max_length = field.field_info.max_length + metadata.max_digits = getattr(field.type_, "max_digits", None) + metadata.decimal_places = getattr(field.type_, "decimal_places", None) + return metadata + + +def get_column_from_field(field: Any) -> Column: # type: ignore + if IS_PYDANTIC_V2: + field_info = field + else: + field_info = field.field_info + sa_column = getattr(field_info, "sa_column", PydanticUndefined) + if isinstance(sa_column, Column): + return sa_column + sa_type = get_sqlalchemy_type(field) + primary_key = getattr(field_info, "primary_key", PydanticUndefined) + if primary_key is PydanticUndefined: + primary_key = False + index = getattr(field_info, "index", PydanticUndefined) + if index is PydanticUndefined: + index = False + nullable = not primary_key and _is_field_noneable(field) + # Override derived nullability if the nullable property is set explicitly + # on the field + field_nullable = getattr(field_info, "nullable", PydanticUndefined) # noqa: B009 + if field_nullable is not PydanticUndefined: + assert not isinstance(field_nullable, PydanticUndefinedType) + nullable = field_nullable + args = [] + foreign_key = getattr(field_info, "foreign_key", PydanticUndefined) + if foreign_key is PydanticUndefined: + foreign_key = None + unique = getattr(field_info, "unique", PydanticUndefined) + if unique is PydanticUndefined: + unique = False + if foreign_key: + assert isinstance(foreign_key, str) + args.append(ForeignKey(foreign_key)) + kwargs = { + "primary_key": primary_key, + "nullable": nullable, + "index": index, + "unique": unique, + } + sa_default = PydanticUndefined + if field_info.default_factory: + sa_default = field_info.default_factory + elif field_info.default is not PydanticUndefined: + sa_default = field_info.default + if sa_default is not PydanticUndefined: + kwargs["default"] = sa_default + sa_column_args = getattr(field_info, "sa_column_args", PydanticUndefined) + if sa_column_args is not PydanticUndefined: + args.extend(list(cast(Sequence[Any], sa_column_args))) + sa_column_kwargs = getattr(field_info, "sa_column_kwargs", PydanticUndefined) + if sa_column_kwargs is not PydanticUndefined: + kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) + return Column(sa_type, *args, **kwargs) # type: ignore diff --git a/sqlmodel/engine/create.py b/sqlmodel/engine/create.py index b2d567b1b1..97481259e2 100644 --- a/sqlmodel/engine/create.py +++ b/sqlmodel/engine/create.py @@ -136,4 +136,4 @@ def create_engine( if not isinstance(query_cache_size, _DefaultPlaceholder): current_kwargs["query_cache_size"] = query_cache_size current_kwargs.update(kwargs) - return _create_engine(url, **current_kwargs) # type: ignore + return _create_engine(url, **current_kwargs) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 2b69dd2a75..cb008bb663 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,17 +1,10 @@ -import ipaddress -import uuid import weakref -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from enum import Enum -from pathlib import Path from typing import ( AbstractSet, Any, Callable, ClassVar, Dict, - ForwardRef, List, Mapping, Optional, @@ -25,34 +18,43 @@ overload, ) -from pydantic import BaseConfig, BaseModel -from pydantic.errors import ConfigError, DictError -from pydantic.fields import SHAPE_SINGLETON, ModelField, Undefined, UndefinedType +from pydantic import BaseModel from pydantic.fields import FieldInfo as PydanticFieldInfo -from pydantic.main import ModelMetaclass, validate_model -from pydantic.typing import NoArgAnyCallable, resolve_annotations -from pydantic.utils import ROOT_KEY, Representation +from pydantic.utils import Representation from sqlalchemy import ( - Boolean, Column, - Date, - DateTime, - Float, - ForeignKey, - Integer, - Interval, - Numeric, inspect, ) -from sqlalchemy import Enum as sa_Enum from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData -from sqlalchemy.sql.sqltypes import LargeBinary, Time -from .sql.sqltypes import GUID, AutoString +from .compat import ( + IS_PYDANTIC_V2, + ModelMetaclass, + NoArgAnyCallable, + PydanticModelConfig, + PydanticUndefined, + PydanticUndefinedType, + SQLModelConfig, + class_dict_is_table, + cls_is_table, + get_annotations, + get_column_from_field, + get_config_value, + get_model_fields, + get_relationship_to, + set_config_value, + set_empty_defaults, + set_fields_set, +) + +if not IS_PYDANTIC_V2: + from pydantic.errors import ConfigError, DictError + from pydantic.main import validate_model + from pydantic.utils import ROOT_KEY _T = TypeVar("_T") @@ -68,50 +70,50 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): - def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: + def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) - nullable = kwargs.pop("nullable", Undefined) - foreign_key = kwargs.pop("foreign_key", Undefined) + nullable = kwargs.pop("nullable", PydanticUndefined) + foreign_key = kwargs.pop("foreign_key", PydanticUndefined) unique = kwargs.pop("unique", False) - index = kwargs.pop("index", Undefined) - sa_type = kwargs.pop("sa_type", Undefined) - sa_column = kwargs.pop("sa_column", Undefined) - sa_column_args = kwargs.pop("sa_column_args", Undefined) - sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) - if sa_column is not Undefined: - if sa_column_args is not Undefined: + index = kwargs.pop("index", PydanticUndefined) + sa_type = kwargs.pop("sa_type", PydanticUndefined) + sa_column = kwargs.pop("sa_column", PydanticUndefined) + sa_column_args = kwargs.pop("sa_column_args", PydanticUndefined) + sa_column_kwargs = kwargs.pop("sa_column_kwargs", PydanticUndefined) + if sa_column is not PydanticUndefined: + if sa_column_args is not PydanticUndefined: raise RuntimeError( "Passing sa_column_args is not supported when " "also passing a sa_column" ) - if sa_column_kwargs is not Undefined: + if sa_column_kwargs is not PydanticUndefined: raise RuntimeError( "Passing sa_column_kwargs is not supported when " "also passing a sa_column" ) - if primary_key is not Undefined: + if primary_key is not PydanticUndefined: raise RuntimeError( "Passing primary_key is not supported when " "also passing a sa_column" ) - if nullable is not Undefined: + if nullable is not PydanticUndefined: raise RuntimeError( "Passing nullable is not supported when " "also passing a sa_column" ) - if foreign_key is not Undefined: + if foreign_key is not PydanticUndefined: raise RuntimeError( "Passing foreign_key is not supported when " "also passing a sa_column" ) - if unique is not Undefined: + if unique is not PydanticUndefined: raise RuntimeError( "Passing unique is not supported when also passing a sa_column" ) - if index is not Undefined: + if index is not PydanticUndefined: raise RuntimeError( "Passing index is not supported when also passing a sa_column" ) - if sa_type is not Undefined: + if sa_type is not PydanticUndefined: raise RuntimeError( "Passing sa_type is not supported when also passing a sa_column" ) @@ -157,7 +159,7 @@ def __init__( @overload def Field( - default: Any = Undefined, + default: Any = PydanticUndefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -186,14 +188,16 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - primary_key: Union[bool, UndefinedType] = Undefined, - foreign_key: Any = Undefined, - unique: Union[bool, UndefinedType] = Undefined, - nullable: Union[bool, UndefinedType] = Undefined, - index: Union[bool, UndefinedType] = Undefined, - sa_type: Union[Type[Any], UndefinedType] = Undefined, - sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, - sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + primary_key: Union[bool, PydanticUndefinedType] = PydanticUndefined, + foreign_key: Any = PydanticUndefined, + unique: Union[bool, PydanticUndefinedType] = PydanticUndefined, + nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, + index: Union[bool, PydanticUndefinedType] = PydanticUndefined, + sa_type: Union[Type[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column_kwargs: Union[ + Mapping[str, Any], PydanticUndefinedType + ] = PydanticUndefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... @@ -201,7 +205,7 @@ def Field( @overload def Field( - default: Any = Undefined, + default: Any = PydanticUndefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -230,14 +234,14 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore + sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... def Field( - default: Any = Undefined, + default: Any = PydanticUndefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -266,15 +270,17 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - primary_key: Union[bool, UndefinedType] = Undefined, - foreign_key: Any = Undefined, - unique: Union[bool, UndefinedType] = Undefined, - nullable: Union[bool, UndefinedType] = Undefined, - index: Union[bool, UndefinedType] = Undefined, - sa_type: Union[Type[Any], UndefinedType] = Undefined, - sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore - sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, - sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + primary_key: Union[bool, PydanticUndefinedType] = PydanticUndefined, + foreign_key: Any = PydanticUndefined, + unique: Union[bool, PydanticUndefinedType] = PydanticUndefined, + nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, + index: Union[bool, PydanticUndefinedType] = PydanticUndefined, + sa_type: Union[Type[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore + sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column_kwargs: Union[ + Mapping[str, Any], PydanticUndefinedType + ] = PydanticUndefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} @@ -314,7 +320,6 @@ def Field( sa_column_kwargs=sa_column_kwargs, **current_schema_extra, ) - field_info._validate() return field_info @@ -343,7 +348,7 @@ def Relationship( *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty] = None, # type: ignore + sa_relationship: Optional[RelationshipProperty] = None, sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> Any: @@ -360,18 +365,22 @@ def Relationship( @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] - __config__: Type[BaseConfig] - __fields__: Dict[str, ModelField] + if IS_PYDANTIC_V2: + model_config: SQLModelConfig + model_fields: Dict[str, FieldInfo] + else: + __config__: Type[SQLModelConfig] + __fields__: Dict[str, FieldInfo] # Replicate SQLAlchemy def __setattr__(cls, name: str, value: Any) -> None: - if getattr(cls.__config__, "table", False): + if get_config_value(cls, "table", False): DeclarativeMeta.__setattr__(cls, name, value) else: super().__setattr__(name, value) def __delattr__(cls, name: str) -> None: - if getattr(cls.__config__, "table", False): + if get_config_value(cls, "table", False): DeclarativeMeta.__delattr__(cls, name) else: super().__delattr__(name) @@ -386,9 +395,7 @@ def __new__( ) -> Any: relationships: Dict[str, RelationshipInfo] = {} dict_for_pydantic = {} - original_annotations = resolve_annotations( - class_dict.get("__annotations__", {}), class_dict.get("__module__", None) - ) + original_annotations = get_annotations(class_dict) pydantic_annotations = {} relationship_annotations = {} for k, v in class_dict.items(): @@ -412,17 +419,20 @@ def __new__( # superclass causing an error allowed_config_kwargs: Set[str] = { key - for key in dir(BaseConfig) + for key in dir(PydanticModelConfig) if not ( key.startswith("__") and key.endswith("__") ) # skip dunder methods and attributes } - pydantic_kwargs = kwargs.copy() config_kwargs = { - key: pydantic_kwargs.pop(key) - for key in pydantic_kwargs.keys() & allowed_config_kwargs + key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } - new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + if class_dict_is_table(class_dict, kwargs): + set_empty_defaults(pydantic_annotations, dict_used) + + new_cls: Type["SQLModelMetaclass"] = super().__new__( + cls, name, bases, dict_used, **config_kwargs + ) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, @@ -430,33 +440,35 @@ def __new__( } def get_config(name: str) -> Any: - config_class_value = getattr(new_cls.__config__, name, Undefined) - if config_class_value is not Undefined: + config_class_value = get_config_value(new_cls, name, PydanticUndefined) + if config_class_value is not PydanticUndefined: return config_class_value - kwarg_value = kwargs.get(name, Undefined) - if kwarg_value is not Undefined: + kwarg_value = kwargs.get(name, PydanticUndefined) + if kwarg_value is not PydanticUndefined: return kwarg_value - return Undefined + return PydanticUndefined config_table = get_config("table") if config_table is True: # If it was passed by kwargs, ensure it's also set in config - new_cls.__config__.table = config_table - for k, v in new_cls.__fields__.items(): + set_config_value(new_cls, "table", config_table) + for k, v in get_model_fields(new_cls).items(): col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. - # This could be done by reading new_cls.__config__.table in FastAPI, but + # This could be done by reading new_cls.model_config['table'] in FastAPI, but # that's very specific about SQLModel, so let's have another config that # other future tools based on Pydantic can use. - new_cls.__config__.read_with_orm_mode = True + set_config_value( + new_cls, "read_from_attributes", True, v1_parameter="read_with_orm_mode" + ) config_registry = get_config("registry") - if config_registry is not Undefined: + if config_registry is not PydanticUndefined: config_registry = cast(registry, config_registry) # If it was passed by kwargs, ensure it's also set in config - new_cls.__config__.registry = config_table + set_config_value(new_cls, "registry", config_table) setattr(new_cls, "_sa_registry", config_registry) # noqa: B010 setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010 setattr(new_cls, "__abstract__", True) # noqa: B010 @@ -470,13 +482,8 @@ def __init__( # this allows FastAPI cloning a SQLModel for the response_model without # trying to create a new SQLAlchemy, for a new table, with the same name, that # triggers an error - base_is_table = False - for base in bases: - config = getattr(base, "__config__") # noqa: B009 - if config and getattr(config, "table", False): - base_is_table = True - break - if getattr(cls.__config__, "table", False) and not base_is_table: + base_is_table = any(cls_is_table(base) for base in bases) + if cls_is_table(cls) and not base_is_table: for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): if rel_info.sa_relationship: # There's a SQLAlchemy relationship declared, that takes precedence @@ -484,16 +491,7 @@ def __init__( setattr(cls, rel_name, rel_info.sa_relationship) # Fix #315 continue ann = cls.__annotations__[rel_name] - temp_field = ModelField.infer( - name=rel_name, - value=rel_info, - annotation=ann, - class_validators=None, - config=BaseConfig, - ) - relationship_to = temp_field.type_ - if isinstance(temp_field.type_, ForwardRef): - relationship_to = temp_field.type_.__forward_arg__ + relationship_to = get_relationship_to(rel_name, rel_info, ann) rel_kwargs: Dict[str, Any] = {} if rel_info.back_populates: rel_kwargs["back_populates"] = rel_info.back_populates @@ -511,7 +509,7 @@ def __init__( rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty = relationship( # type: ignore + rel_value: RelationshipProperty = relationship( relationship_to, *rel_args, **rel_kwargs ) setattr(cls, rel_name, rel_value) # Fix #315 @@ -523,104 +521,6 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) -def get_sqlalchemy_type(field: ModelField) -> Any: - sa_type = getattr(field.field_info, "sa_type", Undefined) # noqa: B009 - if sa_type is not Undefined: - return sa_type - if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: - # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI - if issubclass(field.type_, Enum): - return sa_Enum(field.type_) - if issubclass(field.type_, str): - if field.field_info.max_length: - return AutoString(length=field.field_info.max_length) - return AutoString - if issubclass(field.type_, float): - return Float - if issubclass(field.type_, bool): - return Boolean - if issubclass(field.type_, int): - return Integer - if issubclass(field.type_, datetime): - return DateTime - if issubclass(field.type_, date): - return Date - if issubclass(field.type_, timedelta): - return Interval - if issubclass(field.type_, time): - return Time - if issubclass(field.type_, bytes): - return LargeBinary - if issubclass(field.type_, Decimal): - return Numeric( - precision=getattr(field.type_, "max_digits", None), - scale=getattr(field.type_, "decimal_places", None), - ) - if issubclass(field.type_, ipaddress.IPv4Address): - return AutoString - if issubclass(field.type_, ipaddress.IPv4Network): - return AutoString - if issubclass(field.type_, ipaddress.IPv6Address): - return AutoString - if issubclass(field.type_, ipaddress.IPv6Network): - return AutoString - if issubclass(field.type_, Path): - return AutoString - if issubclass(field.type_, uuid.UUID): - return GUID - raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") - - -def get_column_from_field(field: ModelField) -> Column: # type: ignore - sa_column = getattr(field.field_info, "sa_column", Undefined) - if isinstance(sa_column, Column): - return sa_column - sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field.field_info, "primary_key", Undefined) - if primary_key is Undefined: - primary_key = False - index = getattr(field.field_info, "index", Undefined) - if index is Undefined: - index = False - nullable = not primary_key and _is_field_noneable(field) - # Override derived nullability if the nullable property is set explicitly - # on the field - field_nullable = getattr(field.field_info, "nullable", Undefined) # noqa: B009 - if field_nullable != Undefined: - assert not isinstance(field_nullable, UndefinedType) - nullable = field_nullable - args = [] - foreign_key = getattr(field.field_info, "foreign_key", Undefined) - if foreign_key is Undefined: - foreign_key = None - unique = getattr(field.field_info, "unique", Undefined) - if unique is Undefined: - unique = False - if foreign_key: - assert isinstance(foreign_key, str) - args.append(ForeignKey(foreign_key)) - kwargs = { - "primary_key": primary_key, - "nullable": nullable, - "index": index, - "unique": unique, - } - sa_default = Undefined - if field.field_info.default_factory: - sa_default = field.field_info.default_factory - elif field.field_info.default is not Undefined: - sa_default = field.field_info.default - if sa_default is not Undefined: - kwargs["default"] = sa_default - sa_column_args = getattr(field.field_info, "sa_column_args", Undefined) - if sa_column_args is not Undefined: - args.extend(list(cast(Sequence[Any], sa_column_args))) - sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined) - if sa_column_kwargs is not Undefined: - kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) - return Column(sa_type, *args, **kwargs) # type: ignore - - class_registry = weakref.WeakValueDictionary() # type: ignore default_registry = registry() @@ -639,12 +539,17 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) __tablename__: ClassVar[Union[str, Callable[..., str]]] - __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] __name__: ClassVar[str] metadata: ClassVar[MetaData] + __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six - class Config: - orm_mode = True + if IS_PYDANTIC_V2: + model_config = SQLModelConfig(from_attributes=True) + else: + + class Config: + orm_mode = True def __new__(cls, *args: Any, **kwargs: Any) -> Any: new_object = super().__new__(cls) @@ -653,28 +558,35 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any: # Set __fields_set__ here, that would have been set when calling __init__ # in the Pydantic model so that when SQLAlchemy sets attributes that are # added (e.g. when querying from DB) to the __fields_set__, this already exists - object.__setattr__(new_object, "__fields_set__", set()) + set_fields_set(new_object, set()) return new_object def __init__(__pydantic_self__, **data: Any) -> None: # Uses something other than `self` the first arg to allow "self" as a # settable attribute - values, fields_set, validation_error = validate_model( - __pydantic_self__.__class__, data - ) - # Only raise errors if not a SQLModel model - if ( - not getattr(__pydantic_self__.__config__, "table", False) - and validation_error - ): - raise validation_error - # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy - # can handle them - # object.__setattr__(__pydantic_self__, '__dict__', values) - for key, value in values.items(): - setattr(__pydantic_self__, key, value) - object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) - non_pydantic_keys = data.keys() - values.keys() + if IS_PYDANTIC_V2: + old_dict = __pydantic_self__.__dict__.copy() + super().__init__(**data) # noqa + __pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__} + non_pydantic_keys = data.keys() - __pydantic_self__.model_fields + else: + values, fields_set, validation_error = validate_model( + __pydantic_self__.__class__, data + ) + # Only raise errors if not a SQLModel model + if ( + not getattr(__pydantic_self__.__config__, "table", False) # noqa + and validation_error + ): + raise validation_error + # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy + # can handle them + # object.__setattr__(__pydantic_self__, '__dict__', values) + for key, value in values.items(): + setattr(__pydantic_self__, key, value) + object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) + non_pydantic_keys = data.keys() - values.keys() + for key in non_pydantic_keys: if key in __pydantic_self__.__sqlmodel_relationships__: setattr(__pydantic_self__, key, data[key]) @@ -685,59 +597,13 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if getattr(self.__config__, "table", False) and is_instrumented(self, name): + if get_config_value(self, "table", False) and is_instrumented(self, name): set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values if name not in self.__sqlmodel_relationships__: super().__setattr__(name, value) - @classmethod - def from_orm( - cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None - ) -> _TSQLModel: - # Duplicated from Pydantic - if not cls.__config__.orm_mode: - raise ConfigError( - "You must have the config attribute orm_mode=True to use from_orm" - ) - obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj) - # SQLModel, support update dict - if update is not None: - obj = {**obj, **update} - # End SQLModel support dict - if not getattr(cls.__config__, "table", False): - # If not table, normal Pydantic code - m: _TSQLModel = cls.__new__(cls) - else: - # If table, create the new instance normally to make SQLAlchemy create - # the _sa_instance_state attribute - m = cls() - values, fields_set, validation_error = validate_model(cls, obj) - if validation_error: - raise validation_error - # Updated to trigger SQLAlchemy internal handling - if not getattr(cls.__config__, "table", False): - object.__setattr__(m, "__dict__", values) - else: - for key, value in values.items(): - setattr(m, key, value) - # Continue with standard Pydantic logic - object.__setattr__(m, "__fields_set__", fields_set) - m._init_private_attributes() - return m - - @classmethod - def parse_obj( - cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None - ) -> _TSQLModel: - obj = cls._enforce_dict_if_root(obj) - # SQLModel, support update dict - if update is not None: - obj = {**obj, **update} - # End SQLModel support dict - return super().parse_obj(obj) - def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: # Don't show SQLAlchemy private attributes return [ @@ -746,78 +612,144 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: if not (isinstance(k, str) and k.startswith("_sa_")) ] - # From Pydantic, override to enforce validation with dict - @classmethod - def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: - if isinstance(value, cls): - return value.copy() if cls.__config__.copy_on_model_validation else value - - value = cls._enforce_dict_if_root(value) - if isinstance(value, dict): - values, fields_set, validation_error = validate_model(cls, value) - if validation_error: - raise validation_error - model = cls(**value) - # Reset fields set, this would have been done in Pydantic in __init__ - object.__setattr__(model, "__fields_set__", fields_set) - return model - elif cls.__config__.orm_mode: - return cls.from_orm(value) - elif cls.__custom_root_type__: - return cls.parse_obj(value) - else: - try: - value_as_dict = dict(value) - except (TypeError, ValueError) as e: - raise DictError() from e - return cls(**value_as_dict) - - # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes - def _calculate_keys( - self, - include: Optional[Mapping[Union[int, str], Any]], - exclude: Optional[Mapping[Union[int, str], Any]], - exclude_unset: bool, - update: Optional[Dict[str, Any]] = None, - ) -> Optional[AbstractSet[str]]: - if include is None and exclude is None and not exclude_unset: - # Original in Pydantic: - # return None - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() - - keys: AbstractSet[str] - if exclude_unset: - keys = self.__fields_set__.copy() - else: - # Original in Pydantic: - # keys = self.__dict__.keys() - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() - if include is not None: - keys &= include.keys() - - if update: - keys -= update.keys() - - if exclude: - keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} - - return keys - @declared_attr # type: ignore def __tablename__(cls) -> str: return cls.__name__.lower() + if IS_PYDANTIC_V2: + + @classmethod + def model_validate( + cls: type[_TSQLModel], + obj: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: dict[str, Any] | None = None, + ) -> _TSQLModel: + # Somehow model validate doesn't call __init__ so it would remove our init logic + validated = super().model_validate( + obj, strict=strict, from_attributes=from_attributes, context=context + ) + return cls(**validated.model_dump(exclude_unset=True)) + + else: + + @classmethod + def from_orm( + cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None + ) -> _TSQLModel: + # Duplicated from Pydantic + if not cls.__config__.orm_mode: # noqa + raise ConfigError( + "You must have the config attribute orm_mode=True to use from_orm" + ) + obj = ( + {ROOT_KEY: obj} + if cls.__custom_root_type__ # noqa + else cls._decompose_class(obj) # noqa + ) + # SQLModel, support update dict + if update is not None: + obj = {**obj, **update} + # End SQLModel support dict + if not getattr(cls.__config__, "table", False): # noqa + # If not table, normal Pydantic code + m: _TSQLModel = cls.__new__(cls) + else: + # If table, create the new instance normally to make SQLAlchemy create + # the _sa_instance_state attribute + m = cls() + values, fields_set, validation_error = validate_model(cls, obj) + if validation_error: + raise validation_error + # Updated to trigger SQLAlchemy internal handling + if not getattr(cls.__config__, "table", False): # noqa + object.__setattr__(m, "__dict__", values) + else: + for key, value in values.items(): + setattr(m, key, value) + # Continue with standard Pydantic logic + object.__setattr__(m, "__fields_set__", fields_set) + m._init_private_attributes() # noqa + return m + + @classmethod + def parse_obj( + cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None + ) -> _TSQLModel: + obj = cls._enforce_dict_if_root(obj) # noqa + # SQLModel, support update dict + if update is not None: + obj = {**obj, **update} + # End SQLModel support dict + return super().parse_obj(obj) + + # From Pydantic, override to enforce validation with dict + @classmethod + def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: + if isinstance(value, cls): + return ( + value.copy() if cls.__config__.copy_on_model_validation else value # noqa + ) -def _is_field_noneable(field: ModelField) -> bool: - if not field.required: - # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) - return field.allow_none and ( - field.shape != SHAPE_SINGLETON or not field.sub_fields - ) - return False + value = cls._enforce_dict_if_root(value) + if isinstance(value, dict): + values, fields_set, validation_error = validate_model(cls, value) + if validation_error: + raise validation_error + model = cls(**value) + # Reset fields set, this would have been done in Pydantic in __init__ + object.__setattr__(model, "__fields_set__", fields_set) + return model + elif cls.__config__.orm_mode: # noqa + return cls.from_orm(value) + elif cls.__custom_root_type__: # noqa + return cls.parse_obj(value) + else: + try: + value_as_dict = dict(value) + except (TypeError, ValueError) as e: + raise DictError() from e + return cls(**value_as_dict) + + # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes + def _calculate_keys( + self, + include: Optional[Mapping[Union[int, str], Any]], + exclude: Optional[Mapping[Union[int, str], Any]], + exclude_unset: bool, + update: Optional[Dict[str, Any]] = None, + ) -> Optional[AbstractSet[str]]: + if include is None and exclude is None and not exclude_unset: + # Original in Pydantic: + # return None + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + return ( + self.__fields__.keys() # noqa + ) # | self.__sqlmodel_relationships__.keys() + + keys: AbstractSet[str] + if exclude_unset: + keys = self.__fields_set__.copy() # noqa + else: + # Original in Pydantic: + # keys = self.__dict__.keys() + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + keys = ( + self.__fields__.keys() # noqa + ) # | self.__sqlmodel_relationships__.keys() + if include is not None: + keys &= include.keys() + + if update: + keys -= update.keys() + + if exclude: + keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} + + return keys diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2 index 26d12a0395..b3acb22c95 100644 --- a/sqlmodel/sql/expression.py.jinja2 +++ b/sqlmodel/sql/expression.py.jinja2 @@ -91,4 +91,4 @@ def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression + return column_expression \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 2b8e5fc29e..020b33a566 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ import pytest from pydantic import BaseModel from sqlmodel import SQLModel +from sqlmodel.compat import IS_PYDANTIC_V2 from sqlmodel.main import default_registry top_level_path = Path(__file__).resolve().parent.parent @@ -67,3 +68,7 @@ def new_print(*args): calls.append(data) return new_print + + +needs_pydanticv2 = pytest.mark.skipif(not IS_PYDANTIC_V2, reason="requires Pydantic v2") +needs_pydanticv1 = pytest.mark.skipif(IS_PYDANTIC_V2, reason="requires Pydantic v1") diff --git a/tests/test_enums.py b/tests/test_enums.py index 194bdefea1..07a04c686e 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -5,6 +5,8 @@ from sqlalchemy.sql.type_api import TypeEngine from sqlmodel import Field, SQLModel +from .conftest import needs_pydanticv1, needs_pydanticv2 + """ Tests related to Enums @@ -72,7 +74,8 @@ def test_sqlite_ddl_sql(capsys): assert "CREATE TYPE" not in captured.out -def test_json_schema_flat_model(): +@needs_pydanticv1 +def test_json_schema_flat_model_pydantic_v1(): assert FlatModel.schema() == { "title": "FlatModel", "type": "object", @@ -92,7 +95,8 @@ def test_json_schema_flat_model(): } -def test_json_schema_inherit_model(): +@needs_pydanticv1 +def test_json_schema_inherit_model_pydantic_v1(): assert InheritModel.schema() == { "title": "InheritModel", "type": "object", @@ -110,3 +114,42 @@ def test_json_schema_inherit_model(): } }, } + + +@needs_pydanticv2 +def test_json_schema_flat_model_pydantic_v2(): + assert FlatModel.model_json_schema() == { + "title": "FlatModel", + "type": "object", + "properties": { + "id": {"default": None, "format": "uuid", "title": "Id", "type": "string"}, + "enum_field": {"allOf": [{"$ref": "#/$defs/MyEnum1"}], "default": None}, + }, + "$defs": { + "MyEnum1": { + "title": "MyEnum1", + "enum": ["A", "B"], + "type": "string", + } + }, + } + + +@needs_pydanticv2 +def test_json_schema_inherit_model_pydantic_v2(): + assert InheritModel.model_json_schema() == { + "title": "InheritModel", + "type": "object", + "properties": { + "id": {"title": "Id", "type": "string", "format": "uuid"}, + "enum_field": {"$ref": "#/$defs/MyEnum2"}, + }, + "required": ["id", "enum_field"], + "$defs": { + "MyEnum2": { + "title": "MyEnum2", + "enum": ["C", "D"], + "type": "string", + } + }, + } diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index 14d560628b..e54e8163b3 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -1,11 +1,16 @@ from typing import Optional +import pytest +from pydantic import ValidationError from sqlalchemy import create_engine, select from sqlalchemy.orm import Session from sqlmodel import Field, SQLModel +from .conftest import needs_pydanticv1, needs_pydanticv2 -def test_allow_instantiation_without_arguments(clear_sqlmodel): + +@needs_pydanticv1 +def test_allow_instantiation_without_arguments_pydantic_v1(clear_sqlmodel): class Item(SQLModel): id: Optional[int] = Field(default=None, primary_key=True) name: str @@ -25,3 +30,33 @@ class Config: assert len(result) == 1 assert isinstance(item.id, int) SQLModel.metadata.clear() + + +def test_not_allow_instantiation_without_arguments_if_not_table(): + class Item(SQLModel): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + description: Optional[str] = None + + with pytest.raises(ValidationError): + Item() + + +@needs_pydanticv2 +def test_allow_instantiation_without_arguments_pydnatic_v2(clear_sqlmodel): + class Item(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + description: Optional[str] = None + + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + item = Item() + item.name = "Rick" + db.add(item) + db.commit() + result = db.execute(select(Item)).scalars().all() + assert len(result) == 1 + assert isinstance(item.id, int) + SQLModel.metadata.clear() diff --git a/tests/test_missing_type.py b/tests/test_missing_type.py index 2185fa43e9..dc31f053ec 100644 --- a/tests/test_missing_type.py +++ b/tests/test_missing_type.py @@ -1,11 +1,12 @@ from typing import Optional import pytest +from pydantic import BaseModel from sqlmodel import Field, SQLModel def test_missing_sql_type(): - class CustomType: + class CustomType(BaseModel): @classmethod def __get_validators__(cls): yield cls.validate diff --git a/tests/test_nullable.py b/tests/test_nullable.py index 1c8b37b218..a40bb5b5f0 100644 --- a/tests/test_nullable.py +++ b/tests/test_nullable.py @@ -58,7 +58,7 @@ class Hero(SQLModel, table=True): ][0] assert "primary_key INTEGER NOT NULL," in create_table_log assert "required_value VARCHAR NOT NULL," in create_table_log - assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log + assert "optional_default_ellipsis VARCHAR," in create_table_log assert "optional_default_none VARCHAR," in create_table_log assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log assert "optional_nullable VARCHAR," in create_table_log diff --git a/tests/test_validation.py b/tests/test_validation.py index ad60fcb945..3265922070 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,12 +1,14 @@ from typing import Optional import pytest -from pydantic import validator from pydantic.error_wrappers import ValidationError from sqlmodel import SQLModel +from .conftest import needs_pydanticv1, needs_pydanticv2 -def test_validation(clear_sqlmodel): + +@needs_pydanticv1 +def test_validation_pydantic_v1(clear_sqlmodel): """Test validation of implicit and explicit None values. # For consistency with pydantic, validators are not to be called on @@ -16,6 +18,7 @@ def test_validation(clear_sqlmodel): https://github.com/samuelcolvin/pydantic/issues/1223 """ + from pydantic import validator class Hero(SQLModel): name: Optional[str] = None @@ -31,3 +34,32 @@ def reject_none(cls, v): with pytest.raises(ValidationError): Hero.validate({"name": None, "age": 25}) + + +@needs_pydanticv2 +def test_validation_pydantic_v2(clear_sqlmodel): + """Test validation of implicit and explicit None values. + + # For consistency with pydantic, validators are not to be called on + # arguments that are not explicitly provided. + + https://github.com/tiangolo/sqlmodel/issues/230 + https://github.com/samuelcolvin/pydantic/issues/1223 + + """ + from pydantic import field_validator + + class Hero(SQLModel): + name: Optional[str] = None + secret_name: Optional[str] = None + age: Optional[int] = None + + @field_validator("name", "secret_name", "age") + def reject_none(cls, v): + assert v is not None + return v + + Hero.model_validate({"age": 25}) + + with pytest.raises(ValidationError): + Hero.model_validate({"name": None, "age": 25})