diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 201abc7c22..3b7cfaf95a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -20,11 +20,8 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version:
- - "3.7"
- - "3.8"
- - "3.9"
- - "3.10"
+ python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
+ pydantic-version: ["pydantic-v1", "pydantic-v2"]
fail-fast: false
steps:
@@ -55,8 +52,12 @@ jobs:
- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: python -m poetry install
- - name: Lint
- run: python -m poetry run bash scripts/lint.sh
+ - name: Install Pydantic v1
+ if: matrix.pydantic-version == 'pydantic-v1'
+ run: pip install "pydantic>=1.10.0,<2.0.0"
+ - name: Install Pydantic v2
+ if: matrix.pydantic-version == 'pydantic-v2'
+ run: pip install "pydantic>=2.0.2,<3.0.0"
- run: mkdir coverage
- name: Test
run: python -m poetry run bash scripts/test.sh
@@ -68,6 +69,8 @@ jobs:
with:
name: coverage
path: coverage
+ - name: Lint
+ run: python -m poetry run bash scripts/lint.sh
coverage-combine:
needs:
- test
diff --git a/docs/tutorial/fastapi/multiple-models.md b/docs/tutorial/fastapi/multiple-models.md
index 6845b9862d..4ea24c6752 100644
--- a/docs/tutorial/fastapi/multiple-models.md
+++ b/docs/tutorial/fastapi/multiple-models.md
@@ -174,13 +174,13 @@ Now we use the type annotation `HeroCreate` for the request JSON data in the `he
# Code below omitted 👇
```
-Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.from_orm()`.
+Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.model_validate()`.
-The method `.from_orm()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`.
+The method `.model_validate()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`.
The alternative is `Hero.parse_obj()` that reads data from a dictionary.
-But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.from_orm()` to read those attributes.
+But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.model_validate()` to read those attributes.
With this, we create a new `Hero` instance (the one for the database) and put it in the variable `db_hero` from the data in the `hero` variable that is the `HeroCreate` instance we received from the request.
diff --git a/docs/tutorial/index.md b/docs/tutorial/index.md
index 74107776c2..54e1147d68 100644
--- a/docs/tutorial/index.md
+++ b/docs/tutorial/index.md
@@ -64,6 +64,8 @@ $ cd sqlmodel-tutorial
Make sure you have an officially supported version of Python.
+Currently it is **Python 3.7** and above (Python 3.6 was already deprecated).
+
You can check which version you have with:
@@ -79,9 +81,11 @@ There's a chance that you have multiple Python versions installed.
You might want to try with the specific versions, for example with:
+* `python3.11`
* `python3.10`
* `python3.9`
* `python3.8`
+* `python3.7`
The code would look like this:
diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py
index 3f0602e4b4..f305f75194 100644
--- a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py
+++ b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py
@@ -2,6 +2,7 @@
from fastapi import Depends, FastAPI, HTTPException, Query
from sqlmodel import Field, Session, SQLModel, create_engine, select
+from sqlmodel.compat import IS_PYDANTIC_V2
class HeroBase(SQLModel):
@@ -54,7 +55,10 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead)
def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate):
- db_hero = Hero.from_orm(hero)
+ if IS_PYDANTIC_V2:
+ db_hero = Hero.model_validate(hero)
+ else:
+ db_hero = Hero.from_orm(hero)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
diff --git a/docs_src/tutorial/fastapi/delete/tutorial001.py b/docs_src/tutorial/fastapi/delete/tutorial001.py
index 3069fc5e87..f186c42b2b 100644
--- a/docs_src/tutorial/fastapi/delete/tutorial001.py
+++ b/docs_src/tutorial/fastapi/delete/tutorial001.py
@@ -2,6 +2,7 @@
from fastapi import FastAPI, HTTPException, Query
from sqlmodel import Field, Session, SQLModel, create_engine, select
+from sqlmodel.compat import IS_PYDANTIC_V2
class HeroBase(SQLModel):
@@ -50,7 +51,10 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate):
with Session(engine) as session:
- db_hero = Hero.from_orm(hero)
+ if IS_PYDANTIC_V2:
+ db_hero = Hero.model_validate(hero)
+ else:
+ db_hero = Hero.from_orm(hero)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
diff --git a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py
index 2b8739ca70..6701355f17 100644
--- a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py
+++ b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py
@@ -2,6 +2,7 @@
from fastapi import FastAPI, HTTPException, Query
from sqlmodel import Field, Session, SQLModel, create_engine, select
+from sqlmodel.compat import IS_PYDANTIC_V2
class HeroBase(SQLModel):
@@ -44,7 +45,10 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate):
with Session(engine) as session:
- db_hero = Hero.from_orm(hero)
+ if IS_PYDANTIC_V2:
+ db_hero = Hero.model_validate(hero)
+ else:
+ db_hero = Hero.from_orm(hero)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py
index df20123333..0ceed94ca1 100644
--- a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py
+++ b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py
@@ -2,6 +2,7 @@
from fastapi import FastAPI
from sqlmodel import Field, Session, SQLModel, create_engine, select
+from sqlmodel.compat import IS_PYDANTIC_V2
class Hero(SQLModel, table=True):
@@ -46,7 +47,10 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate):
with Session(engine) as session:
- db_hero = Hero.from_orm(hero)
+ if IS_PYDANTIC_V2:
+ db_hero = Hero.model_validate(hero)
+ else:
+ db_hero = Hero.from_orm(hero)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py
index 392c2c5829..d92745a339 100644
--- a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py
+++ b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py
@@ -2,6 +2,7 @@
from fastapi import FastAPI
from sqlmodel import Field, Session, SQLModel, create_engine, select
+from sqlmodel.compat import IS_PYDANTIC_V2
class HeroBase(SQLModel):
@@ -44,7 +45,10 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate):
with Session(engine) as session:
- db_hero = Hero.from_orm(hero)
+ if IS_PYDANTIC_V2:
+ db_hero = Hero.model_validate(hero)
+ else:
+ db_hero = Hero.from_orm(hero)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001.py b/docs_src/tutorial/fastapi/read_one/tutorial001.py
index 4d66e471a5..aa805d6c8f 100644
--- a/docs_src/tutorial/fastapi/read_one/tutorial001.py
+++ b/docs_src/tutorial/fastapi/read_one/tutorial001.py
@@ -2,6 +2,7 @@
from fastapi import FastAPI, HTTPException
from sqlmodel import Field, Session, SQLModel, create_engine, select
+from sqlmodel.compat import IS_PYDANTIC_V2
class HeroBase(SQLModel):
@@ -44,7 +45,10 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate):
with Session(engine) as session:
- db_hero = Hero.from_orm(hero)
+ if IS_PYDANTIC_V2:
+ db_hero = Hero.model_validate(hero)
+ else:
+ db_hero = Hero.from_orm(hero)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001.py b/docs_src/tutorial/fastapi/relationships/tutorial001.py
index 8477e4a2a0..dfcedaf881 100644
--- a/docs_src/tutorial/fastapi/relationships/tutorial001.py
+++ b/docs_src/tutorial/fastapi/relationships/tutorial001.py
@@ -2,6 +2,7 @@
from fastapi import Depends, FastAPI, HTTPException, Query
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
+from sqlmodel.compat import IS_PYDANTIC_V2
class TeamBase(SQLModel):
@@ -92,7 +93,10 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead)
def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate):
- db_hero = Hero.from_orm(hero)
+ if IS_PYDANTIC_V2:
+ db_hero = Hero.model_validate(hero)
+ else:
+ db_hero = Hero.from_orm(hero)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
@@ -146,7 +150,10 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int):
@app.post("/teams/", response_model=TeamRead)
def create_team(*, session: Session = Depends(get_session), team: TeamCreate):
- db_team = Team.from_orm(team)
+ if IS_PYDANTIC_V2:
+ db_team = Team.model_validate(team)
+ else:
+ db_team = Team.from_orm(team)
session.add(db_team)
session.commit()
session.refresh(db_team)
diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py
index 3f0602e4b4..f305f75194 100644
--- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py
+++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py
@@ -2,6 +2,7 @@
from fastapi import Depends, FastAPI, HTTPException, Query
from sqlmodel import Field, Session, SQLModel, create_engine, select
+from sqlmodel.compat import IS_PYDANTIC_V2
class HeroBase(SQLModel):
@@ -54,7 +55,10 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead)
def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate):
- db_hero = Hero.from_orm(hero)
+ if IS_PYDANTIC_V2:
+ db_hero = Hero.model_validate(hero)
+ else:
+ db_hero = Hero.from_orm(hero)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
diff --git a/docs_src/tutorial/fastapi/teams/tutorial001.py b/docs_src/tutorial/fastapi/teams/tutorial001.py
index 1da0dad8a2..46ea0f933c 100644
--- a/docs_src/tutorial/fastapi/teams/tutorial001.py
+++ b/docs_src/tutorial/fastapi/teams/tutorial001.py
@@ -2,6 +2,7 @@
from fastapi import Depends, FastAPI, HTTPException, Query
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
+from sqlmodel.compat import IS_PYDANTIC_V2
class TeamBase(SQLModel):
@@ -83,7 +84,10 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead)
def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate):
- db_hero = Hero.from_orm(hero)
+ if IS_PYDANTIC_V2:
+ db_hero = Hero.model_validate(hero)
+ else:
+ db_hero = Hero.from_orm(hero)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
@@ -137,7 +141,10 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int):
@app.post("/teams/", response_model=TeamRead)
def create_team(*, session: Session = Depends(get_session), team: TeamCreate):
- db_team = Team.from_orm(team)
+ if IS_PYDANTIC_V2:
+ db_team = Team.model_validate(team)
+ else:
+ db_team = Team.from_orm(team)
session.add(db_team)
session.commit()
session.refresh(db_team)
diff --git a/docs_src/tutorial/fastapi/update/tutorial001.py b/docs_src/tutorial/fastapi/update/tutorial001.py
index bb98efd581..93dfa7496a 100644
--- a/docs_src/tutorial/fastapi/update/tutorial001.py
+++ b/docs_src/tutorial/fastapi/update/tutorial001.py
@@ -2,6 +2,7 @@
from fastapi import FastAPI, HTTPException, Query
from sqlmodel import Field, Session, SQLModel, create_engine, select
+from sqlmodel.compat import IS_PYDANTIC_V2
class HeroBase(SQLModel):
@@ -50,7 +51,10 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate):
with Session(engine) as session:
- db_hero = Hero.from_orm(hero)
+ if IS_PYDANTIC_V2:
+ db_hero = Hero.model_validate(hero)
+ else:
+ db_hero = Hero.from_orm(hero)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
diff --git a/docs_src/tutorial/many_to_many/tutorial003.py b/docs_src/tutorial/many_to_many/tutorial003.py
index 1e03c4af89..cec6e56560 100644
--- a/docs_src/tutorial/many_to_many/tutorial003.py
+++ b/docs_src/tutorial/many_to_many/tutorial003.py
@@ -3,25 +3,12 @@
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
-class HeroTeamLink(SQLModel, table=True):
- team_id: Optional[int] = Field(
- default=None, foreign_key="team.id", primary_key=True
- )
- hero_id: Optional[int] = Field(
- default=None, foreign_key="hero.id", primary_key=True
- )
- is_training: bool = False
-
- team: "Team" = Relationship(back_populates="hero_links")
- hero: "Hero" = Relationship(back_populates="team_links")
-
-
class Team(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(index=True)
headquarters: str
- hero_links: List[HeroTeamLink] = Relationship(back_populates="team")
+ hero_links: List["HeroTeamLink"] = Relationship(back_populates="team")
class Hero(SQLModel, table=True):
@@ -30,7 +17,20 @@ class Hero(SQLModel, table=True):
secret_name: str
age: Optional[int] = Field(default=None, index=True)
- team_links: List[HeroTeamLink] = Relationship(back_populates="hero")
+ team_links: List["HeroTeamLink"] = Relationship(back_populates="hero")
+
+
+class HeroTeamLink(SQLModel, table=True):
+ team_id: Optional[int] = Field(
+ default=None, foreign_key="team.id", primary_key=True
+ )
+ hero_id: Optional[int] = Field(
+ default=None, foreign_key="hero.id", primary_key=True
+ )
+ is_training: bool = False
+
+ team: "Team" = Relationship(back_populates="hero_links")
+ hero: "Hero" = Relationship(back_populates="team_links")
sqlite_file_name = "database.db"
diff --git a/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py b/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py
index 98e197002e..8d91a0bc25 100644
--- a/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py
+++ b/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py
@@ -3,6 +3,21 @@
from sqlmodel import Field, Relationship, SQLModel, create_engine
+class Hero(SQLModel, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ name: str = Field(index=True)
+ secret_name: str
+ age: Optional[int] = Field(default=None, index=True)
+
+ team_id: Optional[int] = Field(default=None, foreign_key="team.id")
+ team: Optional["Team"] = Relationship(back_populates="heroes")
+
+ weapon_id: Optional[int] = Field(default=None, foreign_key="weapon.id")
+ weapon: Optional["Weapon"] = Relationship(back_populates="hero")
+
+ powers: List["Power"] = Relationship(back_populates="hero")
+
+
class Weapon(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(index=True)
@@ -26,21 +41,6 @@ class Team(SQLModel, table=True):
heroes: List["Hero"] = Relationship(back_populates="team")
-class Hero(SQLModel, table=True):
- id: Optional[int] = Field(default=None, primary_key=True)
- name: str = Field(index=True)
- secret_name: str
- age: Optional[int] = Field(default=None, index=True)
-
- team_id: Optional[int] = Field(default=None, foreign_key="team.id")
- team: Optional[Team] = Relationship(back_populates="heroes")
-
- weapon_id: Optional[int] = Field(default=None, foreign_key="weapon.id")
- weapon: Optional[Weapon] = Relationship(back_populates="hero")
-
- powers: List[Power] = Relationship(back_populates="hero")
-
-
sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
diff --git a/pyproject.toml b/pyproject.toml
index 23fa79bf31..f104631655 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@ classifiers = [
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
"Topic :: Database",
"Topic :: Database :: Database Engines/Servers",
"Topic :: Internet",
@@ -45,7 +46,7 @@ pillow = "^9.3.0"
cairosvg = "^2.5.2"
mdx-include = "^1.4.1"
coverage = {extras = ["toml"], version = ">=6.2,<8.0"}
-fastapi = "^0.68.1"
+fastapi = "^0.100.0"
requests = "^2.26.0"
ruff = "^0.1.2"
diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py
index 495ac9c8a8..7e20e1ba41 100644
--- a/sqlmodel/__init__.py
+++ b/sqlmodel/__init__.py
@@ -21,7 +21,6 @@
from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint
from sqlalchemy.schema import Sequence as Sequence
from sqlalchemy.schema import Table as Table
-from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData
from sqlalchemy.schema import UniqueConstraint as UniqueConstraint
from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
from sqlalchemy.sql import (
@@ -31,6 +30,7 @@
from sqlalchemy.sql import (
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
)
+from sqlalchemy.sql import Subquery as Subquery
from sqlalchemy.sql import alias as alias
from sqlalchemy.sql import all_ as all_
from sqlalchemy.sql import and_ as and_
@@ -71,7 +71,6 @@
from sqlalchemy.sql import outerjoin as outerjoin
from sqlalchemy.sql import outparam as outparam
from sqlalchemy.sql import over as over
-from sqlalchemy.sql import subquery as subquery
from sqlalchemy.sql import table as table
from sqlalchemy.sql import tablesample as tablesample
from sqlalchemy.sql import text as text
diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py
new file mode 100644
index 0000000000..dbd22053a8
--- /dev/null
+++ b/sqlmodel/compat.py
@@ -0,0 +1,416 @@
+import ipaddress
+import uuid
+from datetime import date, datetime, time, timedelta
+from decimal import Decimal
+from enum import Enum
+from pathlib import Path
+NoneType = type(None)  # types.NoneType requires Python 3.10+; project supports 3.7+
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ ForwardRef,
+ Optional,
+ Sequence,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ get_args,
+ get_origin,
+)
+
+from pydantic import VERSION as PYDANTIC_VERSION
+from sqlalchemy import (
+ Boolean,
+ Column,
+ Date,
+ DateTime,
+ Float,
+ ForeignKey,
+ Integer,
+ Interval,
+ Numeric,
+)
+from sqlalchemy import Enum as sa_Enum
+from sqlalchemy.sql.sqltypes import LargeBinary, Time
+
+from .sql.sqltypes import GUID, AutoString
+
+IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2
+
+
+if IS_PYDANTIC_V2:
+ from pydantic import ConfigDict as PydanticModelConfig
+ from pydantic._internal._fields import PydanticMetadata
+ from pydantic._internal._model_construction import ModelMetaclass
+ from pydantic_core import PydanticUndefined as PydanticUndefined # noqa
+ from pydantic_core import PydanticUndefinedType as PydanticUndefinedType
+else:
+ from pydantic import BaseConfig as PydanticModelConfig
+ from pydantic.fields import SHAPE_SINGLETON, ModelField
+ from pydantic.fields import Undefined as PydanticUndefined # noqa
+ from pydantic.fields import UndefinedType as PydanticUndefinedType
+ from pydantic.main import ModelMetaclass as ModelMetaclass
+ from pydantic.typing import resolve_annotations
+
+if TYPE_CHECKING:
+ from .main import FieldInfo, RelationshipInfo, SQLModel, SQLModelMetaclass
+
+
+NoArgAnyCallable = Callable[[], Any]
+T = TypeVar("T")
+InstanceOrType = Union[T, Type[T]]
+
+if IS_PYDANTIC_V2:
+
+ class SQLModelConfig(PydanticModelConfig, total=False):
+ table: Optional[bool]
+ registry: Optional[Any]
+
+else:
+
+ class SQLModelConfig(PydanticModelConfig):
+ table: Optional[bool] = None
+ registry: Optional[Any] = None
+
+
+# Inspired from https://github.com/roman-right/beanie/blob/main/beanie/odm/utils/pydantic.py
+def get_model_config(model: type) -> Optional[SQLModelConfig]:
+ if IS_PYDANTIC_V2:
+ return getattr(model, "model_config", None)
+ else:
+ return getattr(model, "__config__", None)
+
+
+def get_config_value(
+ model: InstanceOrType["SQLModel"], parameter: str, default: Any = None
+) -> Any:
+ if IS_PYDANTIC_V2:
+ return model.model_config.get(parameter, default)
+ else:
+ return getattr(model.__config__, parameter, default)
+
+
+def set_config_value(
+ model: InstanceOrType["SQLModel"],
+ parameter: str,
+ value: Any,
+ v1_parameter: Optional[str] = None,
+) -> None:
+ if IS_PYDANTIC_V2:
+ model.model_config[parameter] = value # type: ignore
+ else:
+ setattr(model.__config__, v1_parameter or parameter, value) # type: ignore
+
+
+def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]:
+ if IS_PYDANTIC_V2:
+ return model.model_fields # type: ignore
+ else:
+ return model.__fields__ # type: ignore
+
+
+def get_fields_set(model: InstanceOrType["SQLModel"]) -> "set[str]":
+ if IS_PYDANTIC_V2:
+ return model.__pydantic_fields_set__ # type: ignore
+ else:
+ return model.__fields_set__ # type: ignore
+
+
+def set_fields_set(
+ new_object: InstanceOrType["SQLModel"], fields: "set[str]"
+) -> None:
+ if IS_PYDANTIC_V2:
+ object.__setattr__(new_object, "__pydantic_fields_set__", fields)
+ else:
+ object.__setattr__(new_object, "__fields_set__", fields)
+
+
+def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None:
+ if IS_PYDANTIC_V2:
+ cls.model_config["from_attributes"] = True
+ else:
+ cls.__config__.read_with_orm_mode = True # type: ignore
+
+
+def get_annotations(class_dict: Dict[str, Any]) -> Dict[str, Any]:
+ if IS_PYDANTIC_V2:
+ return class_dict.get("__annotations__", {})
+ else:
+ return resolve_annotations(
+ class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
+ )
+
+
+def class_dict_is_table(
+ class_dict: Dict[str, Any], class_kwargs: Dict[str, Any]
+) -> bool:
+ config: SQLModelConfig = {}
+ if IS_PYDANTIC_V2:
+ config = class_dict.get("model_config", {})
+ else:
+ config = class_dict.get("__config__", {})
+ config_table = config.get("table", PydanticUndefined)
+ if config_table is not PydanticUndefined:
+ return config_table # type: ignore
+ kw_table = class_kwargs.get("table", PydanticUndefined)
+ if kw_table is not PydanticUndefined:
+ return kw_table # type: ignore
+ return False
+
+
+def cls_is_table(cls: Type) -> bool:
+ if IS_PYDANTIC_V2:
+ config = getattr(cls, "model_config", None)
+ if not config:
+ return False
+ return config.get("table", False)
+ else:
+ config = getattr(cls, "__config__", None)
+ if not config:
+ return False
+ return getattr(config, "table", False)
+
+
+def get_relationship_to(
+ name: str,
+ rel_info: "RelationshipInfo",
+ annotation: Any,
+) -> Any:
+ if IS_PYDANTIC_V2:
+ relationship_to = get_origin(annotation)
+ # Direct relationships (e.g. 'Team' or Team) have None as an origin
+ if relationship_to is None:
+ relationship_to = annotation
+ # If Union (e.g. Optional), get the real field
+ elif relationship_to is Union:
+ relationship_to = get_args(annotation)[0]
+ # If a list, then also get the real field
+ elif relationship_to is list:
+ relationship_to = get_args(annotation)[0]
+ if isinstance(relationship_to, ForwardRef):
+ relationship_to = relationship_to.__forward_arg__
+ return relationship_to
+ else:
+ temp_field = ModelField.infer(
+ name=name,
+ value=rel_info,
+ annotation=annotation,
+ class_validators=None,
+ config=SQLModelConfig,
+ )
+ relationship_to = temp_field.type_
+ if isinstance(temp_field.type_, ForwardRef):
+ relationship_to = temp_field.type_.__forward_arg__
+ return relationship_to
+
+
+def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) -> None:
+ """
+ Pydantic v2 without required fields with no optionals cannot do empty initialisations.
+ This means we cannot do Table() and set fields later.
+ We go around this by adding a default to everything, being None
+
+ Args:
+ annotations: Dict[str, Any]: The annotations to provide to pydantic
+ class_dict: Dict[str, Any]: The class dict for the defaults
+ """
+ if IS_PYDANTIC_V2:
+ from .main import FieldInfo
+
+ # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything
+ for key in annotations.keys():
+ value = class_dict.get(key, PydanticUndefined)
+ if value is PydanticUndefined:
+ class_dict[key] = None
+ elif isinstance(value, FieldInfo):
+ if (
+ value.default in (PydanticUndefined, Ellipsis)
+ ) and value.default_factory is None:
+ # So we can check for nullable
+ value.default = None
+
+
+def _is_field_noneable(field: "FieldInfo") -> bool:
+ if IS_PYDANTIC_V2:
+ if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined:
+ return field.nullable # type: ignore
+ if not field.is_required():
+ if field.default is PydanticUndefined:
+ return False
+ if field.annotation is None or field.annotation is NoneType:
+ return True
+ if get_origin(field.annotation) is Union:
+ for base in get_args(field.annotation):
+ if base is NoneType:
+ return True
+ return False
+ return False
+ else:
+ if not field.required:
+ # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
+ return field.allow_none and (
+ field.shape != SHAPE_SINGLETON or not field.sub_fields
+ )
+ return field.allow_none
+
+
+def get_sqlalchemy_type(field: Any) -> Any:
+ if IS_PYDANTIC_V2:
+ field_info = field
+ else:
+ field_info = field.field_info
+ sa_type = getattr(field_info, "sa_type", PydanticUndefined) # noqa: B009
+ if sa_type is not PydanticUndefined:
+ return sa_type
+
+ type_ = get_type_from_field(field)
+ metadata = get_field_metadata(field)
+
+ # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
+ if issubclass(type_, Enum):
+ return sa_Enum(type_)
+ if issubclass(type_, str):
+ max_length = getattr(metadata, "max_length", None)
+ if max_length:
+ return AutoString(length=max_length)
+ return AutoString
+ if issubclass(type_, float):
+ return Float
+ if issubclass(type_, bool):
+ return Boolean
+ if issubclass(type_, int):
+ return Integer
+ if issubclass(type_, datetime):
+ return DateTime
+ if issubclass(type_, date):
+ return Date
+ if issubclass(type_, timedelta):
+ return Interval
+ if issubclass(type_, time):
+ return Time
+ if issubclass(type_, bytes):
+ return LargeBinary
+ if issubclass(type_, Decimal):
+ return Numeric(
+ precision=getattr(metadata, "max_digits", None),
+ scale=getattr(metadata, "decimal_places", None),
+ )
+ if issubclass(type_, ipaddress.IPv4Address):
+ return AutoString
+ if issubclass(type_, ipaddress.IPv4Network):
+ return AutoString
+ if issubclass(type_, ipaddress.IPv6Address):
+ return AutoString
+ if issubclass(type_, ipaddress.IPv6Network):
+ return AutoString
+ if issubclass(type_, Path):
+ return AutoString
+ if issubclass(type_, uuid.UUID):
+ return GUID
+ raise ValueError(f"{type_} has no matching SQLAlchemy type")
+
+
+def get_type_from_field(field: Any) -> type:
+ if IS_PYDANTIC_V2:
+ type_: type | None = field.annotation
+ # Resolve Optional fields
+ if type_ is None:
+ raise ValueError("Missing field type")
+ origin = get_origin(type_)
+ if origin is None:
+ return type_
+ if origin is Union:
+ bases = get_args(type_)
+ if len(bases) > 2:
+ raise ValueError(
+ "Cannot have a (non-optional) union as a SQL alchemy field"
+ )
+ # Non optional unions are not allowed
+ if bases[0] is not NoneType and bases[1] is not NoneType:
+ raise ValueError(
+ "Cannot have a (non-optional) union as a SQL alchemy field"
+ )
+ # Optional unions are allowed
+ return bases[0] if bases[0] is not NoneType else bases[1]
+ return origin
+ else:
+ if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
+ return field.type_
+ raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
+
+
+class FakeMetadata:
+ max_length: Optional[int] = None
+ max_digits: Optional[int] = None
+ decimal_places: Optional[int] = None
+
+
+def get_field_metadata(field: Any) -> Any:
+ if IS_PYDANTIC_V2:
+ for meta in field.metadata:
+ if isinstance(meta, PydanticMetadata):
+ return meta
+ return FakeMetadata()
+ else:
+ metadata = FakeMetadata()
+ metadata.max_length = field.field_info.max_length
+ metadata.max_digits = getattr(field.type_, "max_digits", None)
+ metadata.decimal_places = getattr(field.type_, "decimal_places", None)
+ return metadata
+
+
+def get_column_from_field(field: Any) -> Column: # type: ignore
+ if IS_PYDANTIC_V2:
+ field_info = field
+ else:
+ field_info = field.field_info
+ sa_column = getattr(field_info, "sa_column", PydanticUndefined)
+ if isinstance(sa_column, Column):
+ return sa_column
+ sa_type = get_sqlalchemy_type(field)
+ primary_key = getattr(field_info, "primary_key", PydanticUndefined)
+ if primary_key is PydanticUndefined:
+ primary_key = False
+ index = getattr(field_info, "index", PydanticUndefined)
+ if index is PydanticUndefined:
+ index = False
+ nullable = not primary_key and _is_field_noneable(field)
+ # Override derived nullability if the nullable property is set explicitly
+ # on the field
+ field_nullable = getattr(field_info, "nullable", PydanticUndefined) # noqa: B009
+ if field_nullable is not PydanticUndefined:
+ assert not isinstance(field_nullable, PydanticUndefinedType)
+ nullable = field_nullable
+ args = []
+ foreign_key = getattr(field_info, "foreign_key", PydanticUndefined)
+ if foreign_key is PydanticUndefined:
+ foreign_key = None
+ unique = getattr(field_info, "unique", PydanticUndefined)
+ if unique is PydanticUndefined:
+ unique = False
+ if foreign_key:
+ assert isinstance(foreign_key, str)
+ args.append(ForeignKey(foreign_key))
+ kwargs = {
+ "primary_key": primary_key,
+ "nullable": nullable,
+ "index": index,
+ "unique": unique,
+ }
+ sa_default = PydanticUndefined
+ if field_info.default_factory:
+ sa_default = field_info.default_factory
+ elif field_info.default is not PydanticUndefined:
+ sa_default = field_info.default
+ if sa_default is not PydanticUndefined:
+ kwargs["default"] = sa_default
+ sa_column_args = getattr(field_info, "sa_column_args", PydanticUndefined)
+ if sa_column_args is not PydanticUndefined:
+ args.extend(list(cast(Sequence[Any], sa_column_args)))
+ sa_column_kwargs = getattr(field_info, "sa_column_kwargs", PydanticUndefined)
+ if sa_column_kwargs is not PydanticUndefined:
+ kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
+ return Column(sa_type, *args, **kwargs) # type: ignore
diff --git a/sqlmodel/engine/create.py b/sqlmodel/engine/create.py
index b2d567b1b1..97481259e2 100644
--- a/sqlmodel/engine/create.py
+++ b/sqlmodel/engine/create.py
@@ -136,4 +136,4 @@ def create_engine(
if not isinstance(query_cache_size, _DefaultPlaceholder):
current_kwargs["query_cache_size"] = query_cache_size
current_kwargs.update(kwargs)
- return _create_engine(url, **current_kwargs) # type: ignore
+ return _create_engine(url, **current_kwargs)
diff --git a/sqlmodel/main.py b/sqlmodel/main.py
index 2b69dd2a75..cb008bb663 100644
--- a/sqlmodel/main.py
+++ b/sqlmodel/main.py
@@ -1,17 +1,10 @@
-import ipaddress
-import uuid
import weakref
-from datetime import date, datetime, time, timedelta
-from decimal import Decimal
-from enum import Enum
-from pathlib import Path
from typing import (
AbstractSet,
Any,
Callable,
ClassVar,
Dict,
- ForwardRef,
List,
Mapping,
Optional,
@@ -25,34 +18,43 @@
overload,
)
-from pydantic import BaseConfig, BaseModel
-from pydantic.errors import ConfigError, DictError
-from pydantic.fields import SHAPE_SINGLETON, ModelField, Undefined, UndefinedType
+from pydantic import BaseModel
from pydantic.fields import FieldInfo as PydanticFieldInfo
-from pydantic.main import ModelMetaclass, validate_model
-from pydantic.typing import NoArgAnyCallable, resolve_annotations
-from pydantic.utils import ROOT_KEY, Representation
+from pydantic.utils import Representation
from sqlalchemy import (
- Boolean,
Column,
- Date,
- DateTime,
- Float,
- ForeignKey,
- Integer,
- Interval,
- Numeric,
inspect,
)
-from sqlalchemy import Enum as sa_Enum
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
-from sqlalchemy.sql.sqltypes import LargeBinary, Time
-from .sql.sqltypes import GUID, AutoString
+from .compat import (
+ IS_PYDANTIC_V2,
+ ModelMetaclass,
+ NoArgAnyCallable,
+ PydanticModelConfig,
+ PydanticUndefined,
+ PydanticUndefinedType,
+ SQLModelConfig,
+ class_dict_is_table,
+ cls_is_table,
+ get_annotations,
+ get_column_from_field,
+ get_config_value,
+ get_model_fields,
+ get_relationship_to,
+ set_config_value,
+ set_empty_defaults,
+ set_fields_set,
+)
+
+if not IS_PYDANTIC_V2:
+ from pydantic.errors import ConfigError, DictError
+ from pydantic.main import validate_model
+ from pydantic.utils import ROOT_KEY
_T = TypeVar("_T")
@@ -68,50 +70,50 @@ def __dataclass_transform__(
class FieldInfo(PydanticFieldInfo):
- def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
+ def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
primary_key = kwargs.pop("primary_key", False)
- nullable = kwargs.pop("nullable", Undefined)
- foreign_key = kwargs.pop("foreign_key", Undefined)
+ nullable = kwargs.pop("nullable", PydanticUndefined)
+ foreign_key = kwargs.pop("foreign_key", PydanticUndefined)
unique = kwargs.pop("unique", False)
- index = kwargs.pop("index", Undefined)
- sa_type = kwargs.pop("sa_type", Undefined)
- sa_column = kwargs.pop("sa_column", Undefined)
- sa_column_args = kwargs.pop("sa_column_args", Undefined)
- sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
- if sa_column is not Undefined:
- if sa_column_args is not Undefined:
+ index = kwargs.pop("index", PydanticUndefined)
+ sa_type = kwargs.pop("sa_type", PydanticUndefined)
+ sa_column = kwargs.pop("sa_column", PydanticUndefined)
+ sa_column_args = kwargs.pop("sa_column_args", PydanticUndefined)
+ sa_column_kwargs = kwargs.pop("sa_column_kwargs", PydanticUndefined)
+ if sa_column is not PydanticUndefined:
+ if sa_column_args is not PydanticUndefined:
raise RuntimeError(
"Passing sa_column_args is not supported when "
"also passing a sa_column"
)
- if sa_column_kwargs is not Undefined:
+ if sa_column_kwargs is not PydanticUndefined:
raise RuntimeError(
"Passing sa_column_kwargs is not supported when "
"also passing a sa_column"
)
- if primary_key is not Undefined:
+ if primary_key is not PydanticUndefined:
raise RuntimeError(
"Passing primary_key is not supported when "
"also passing a sa_column"
)
- if nullable is not Undefined:
+ if nullable is not PydanticUndefined:
raise RuntimeError(
"Passing nullable is not supported when " "also passing a sa_column"
)
- if foreign_key is not Undefined:
+ if foreign_key is not PydanticUndefined:
raise RuntimeError(
"Passing foreign_key is not supported when "
"also passing a sa_column"
)
- if unique is not Undefined:
+ if unique is not PydanticUndefined:
raise RuntimeError(
"Passing unique is not supported when also passing a sa_column"
)
- if index is not Undefined:
+ if index is not PydanticUndefined:
raise RuntimeError(
"Passing index is not supported when also passing a sa_column"
)
- if sa_type is not Undefined:
+ if sa_type is not PydanticUndefined:
raise RuntimeError(
"Passing sa_type is not supported when also passing a sa_column"
)
@@ -157,7 +159,7 @@ def __init__(
@overload
def Field(
- default: Any = Undefined,
+ default: Any = PydanticUndefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
@@ -186,14 +188,16 @@ def Field(
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
- primary_key: Union[bool, UndefinedType] = Undefined,
- foreign_key: Any = Undefined,
- unique: Union[bool, UndefinedType] = Undefined,
- nullable: Union[bool, UndefinedType] = Undefined,
- index: Union[bool, UndefinedType] = Undefined,
- sa_type: Union[Type[Any], UndefinedType] = Undefined,
- sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
- sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
+ primary_key: Union[bool, PydanticUndefinedType] = PydanticUndefined,
+ foreign_key: Any = PydanticUndefined,
+ unique: Union[bool, PydanticUndefinedType] = PydanticUndefined,
+ nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
+ index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
+ sa_type: Union[Type[Any], PydanticUndefinedType] = PydanticUndefined,
+ sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
+ sa_column_kwargs: Union[
+ Mapping[str, Any], PydanticUndefinedType
+ ] = PydanticUndefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
...
@@ -201,7 +205,7 @@ def Field(
@overload
def Field(
- default: Any = Undefined,
+ default: Any = PydanticUndefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
@@ -230,14 +234,14 @@ def Field(
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
- sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
+ sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
...
def Field(
- default: Any = Undefined,
+ default: Any = PydanticUndefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
@@ -266,15 +270,17 @@ def Field(
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
- primary_key: Union[bool, UndefinedType] = Undefined,
- foreign_key: Any = Undefined,
- unique: Union[bool, UndefinedType] = Undefined,
- nullable: Union[bool, UndefinedType] = Undefined,
- index: Union[bool, UndefinedType] = Undefined,
- sa_type: Union[Type[Any], UndefinedType] = Undefined,
- sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
- sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
- sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
+ primary_key: Union[bool, PydanticUndefinedType] = PydanticUndefined,
+ foreign_key: Any = PydanticUndefined,
+ unique: Union[bool, PydanticUndefinedType] = PydanticUndefined,
+ nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
+ index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
+ sa_type: Union[Type[Any], PydanticUndefinedType] = PydanticUndefined,
+ sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
+ sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
+ sa_column_kwargs: Union[
+ Mapping[str, Any], PydanticUndefinedType
+ ] = PydanticUndefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
@@ -314,7 +320,6 @@ def Field(
sa_column_kwargs=sa_column_kwargs,
**current_schema_extra,
)
- field_info._validate()
return field_info
@@ -343,7 +348,7 @@ def Relationship(
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
- sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
+ sa_relationship: Optional[RelationshipProperty] = None,
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
@@ -360,18 +365,22 @@ def Relationship(
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
- __config__: Type[BaseConfig]
- __fields__: Dict[str, ModelField]
+ if IS_PYDANTIC_V2:
+ model_config: SQLModelConfig
+ model_fields: Dict[str, FieldInfo]
+ else:
+ __config__: Type[SQLModelConfig]
+ __fields__: Dict[str, FieldInfo]
# Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None:
- if getattr(cls.__config__, "table", False):
+ if get_config_value(cls, "table", False):
DeclarativeMeta.__setattr__(cls, name, value)
else:
super().__setattr__(name, value)
def __delattr__(cls, name: str) -> None:
- if getattr(cls.__config__, "table", False):
+ if get_config_value(cls, "table", False):
DeclarativeMeta.__delattr__(cls, name)
else:
super().__delattr__(name)
@@ -386,9 +395,7 @@ def __new__(
) -> Any:
relationships: Dict[str, RelationshipInfo] = {}
dict_for_pydantic = {}
- original_annotations = resolve_annotations(
- class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
- )
+ original_annotations = get_annotations(class_dict)
pydantic_annotations = {}
relationship_annotations = {}
for k, v in class_dict.items():
@@ -412,17 +419,20 @@ def __new__(
# superclass causing an error
allowed_config_kwargs: Set[str] = {
key
- for key in dir(BaseConfig)
+ for key in dir(PydanticModelConfig)
if not (
key.startswith("__") and key.endswith("__")
) # skip dunder methods and attributes
}
- pydantic_kwargs = kwargs.copy()
config_kwargs = {
- key: pydantic_kwargs.pop(key)
- for key in pydantic_kwargs.keys() & allowed_config_kwargs
+ key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
}
- new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
+ if class_dict_is_table(class_dict, kwargs):
+ set_empty_defaults(pydantic_annotations, dict_used)
+
+ new_cls: Type["SQLModelMetaclass"] = super().__new__(
+ cls, name, bases, dict_used, **config_kwargs
+ )
new_cls.__annotations__ = {
**relationship_annotations,
**pydantic_annotations,
@@ -430,33 +440,35 @@ def __new__(
}
def get_config(name: str) -> Any:
- config_class_value = getattr(new_cls.__config__, name, Undefined)
- if config_class_value is not Undefined:
+ config_class_value = get_config_value(new_cls, name, PydanticUndefined)
+ if config_class_value is not PydanticUndefined:
return config_class_value
- kwarg_value = kwargs.get(name, Undefined)
- if kwarg_value is not Undefined:
+ kwarg_value = kwargs.get(name, PydanticUndefined)
+ if kwarg_value is not PydanticUndefined:
return kwarg_value
- return Undefined
+ return PydanticUndefined
config_table = get_config("table")
if config_table is True:
# If it was passed by kwargs, ensure it's also set in config
- new_cls.__config__.table = config_table
- for k, v in new_cls.__fields__.items():
+ set_config_value(new_cls, "table", config_table)
+ for k, v in get_model_fields(new_cls).items():
col = get_column_from_field(v)
setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field
# in orm_mode instead of preemptively converting it to a dict.
- # This could be done by reading new_cls.__config__.table in FastAPI, but
+ # This could be done by reading new_cls.model_config['table'] in FastAPI, but
# that's very specific about SQLModel, so let's have another config that
# other future tools based on Pydantic can use.
- new_cls.__config__.read_with_orm_mode = True
+ set_config_value(
+ new_cls, "read_from_attributes", True, v1_parameter="read_with_orm_mode"
+ )
config_registry = get_config("registry")
- if config_registry is not Undefined:
+ if config_registry is not PydanticUndefined:
config_registry = cast(registry, config_registry)
# If it was passed by kwargs, ensure it's also set in config
- new_cls.__config__.registry = config_table
+ set_config_value(new_cls, "registry", config_registry)
setattr(new_cls, "_sa_registry", config_registry) # noqa: B010
setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010
setattr(new_cls, "__abstract__", True) # noqa: B010
@@ -470,13 +482,8 @@ def __init__(
# this allows FastAPI cloning a SQLModel for the response_model without
# trying to create a new SQLAlchemy, for a new table, with the same name, that
# triggers an error
- base_is_table = False
- for base in bases:
- config = getattr(base, "__config__") # noqa: B009
- if config and getattr(config, "table", False):
- base_is_table = True
- break
- if getattr(cls.__config__, "table", False) and not base_is_table:
+ base_is_table = any(cls_is_table(base) for base in bases)
+ if cls_is_table(cls) and not base_is_table:
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
# There's a SQLAlchemy relationship declared, that takes precedence
@@ -484,16 +491,7 @@ def __init__(
setattr(cls, rel_name, rel_info.sa_relationship) # Fix #315
continue
ann = cls.__annotations__[rel_name]
- temp_field = ModelField.infer(
- name=rel_name,
- value=rel_info,
- annotation=ann,
- class_validators=None,
- config=BaseConfig,
- )
- relationship_to = temp_field.type_
- if isinstance(temp_field.type_, ForwardRef):
- relationship_to = temp_field.type_.__forward_arg__
+ relationship_to = get_relationship_to(rel_name, rel_info, ann)
rel_kwargs: Dict[str, Any] = {}
if rel_info.back_populates:
rel_kwargs["back_populates"] = rel_info.back_populates
@@ -511,7 +509,7 @@ def __init__(
rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs)
- rel_value: RelationshipProperty = relationship( # type: ignore
+ rel_value: RelationshipProperty = relationship(
relationship_to, *rel_args, **rel_kwargs
)
setattr(cls, rel_name, rel_value) # Fix #315
@@ -523,104 +521,6 @@ def __init__(
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
-def get_sqlalchemy_type(field: ModelField) -> Any:
- sa_type = getattr(field.field_info, "sa_type", Undefined) # noqa: B009
- if sa_type is not Undefined:
- return sa_type
- if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
- # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
- if issubclass(field.type_, Enum):
- return sa_Enum(field.type_)
- if issubclass(field.type_, str):
- if field.field_info.max_length:
- return AutoString(length=field.field_info.max_length)
- return AutoString
- if issubclass(field.type_, float):
- return Float
- if issubclass(field.type_, bool):
- return Boolean
- if issubclass(field.type_, int):
- return Integer
- if issubclass(field.type_, datetime):
- return DateTime
- if issubclass(field.type_, date):
- return Date
- if issubclass(field.type_, timedelta):
- return Interval
- if issubclass(field.type_, time):
- return Time
- if issubclass(field.type_, bytes):
- return LargeBinary
- if issubclass(field.type_, Decimal):
- return Numeric(
- precision=getattr(field.type_, "max_digits", None),
- scale=getattr(field.type_, "decimal_places", None),
- )
- if issubclass(field.type_, ipaddress.IPv4Address):
- return AutoString
- if issubclass(field.type_, ipaddress.IPv4Network):
- return AutoString
- if issubclass(field.type_, ipaddress.IPv6Address):
- return AutoString
- if issubclass(field.type_, ipaddress.IPv6Network):
- return AutoString
- if issubclass(field.type_, Path):
- return AutoString
- if issubclass(field.type_, uuid.UUID):
- return GUID
- raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
-
-
-def get_column_from_field(field: ModelField) -> Column: # type: ignore
- sa_column = getattr(field.field_info, "sa_column", Undefined)
- if isinstance(sa_column, Column):
- return sa_column
- sa_type = get_sqlalchemy_type(field)
- primary_key = getattr(field.field_info, "primary_key", Undefined)
- if primary_key is Undefined:
- primary_key = False
- index = getattr(field.field_info, "index", Undefined)
- if index is Undefined:
- index = False
- nullable = not primary_key and _is_field_noneable(field)
- # Override derived nullability if the nullable property is set explicitly
- # on the field
- field_nullable = getattr(field.field_info, "nullable", Undefined) # noqa: B009
- if field_nullable != Undefined:
- assert not isinstance(field_nullable, UndefinedType)
- nullable = field_nullable
- args = []
- foreign_key = getattr(field.field_info, "foreign_key", Undefined)
- if foreign_key is Undefined:
- foreign_key = None
- unique = getattr(field.field_info, "unique", Undefined)
- if unique is Undefined:
- unique = False
- if foreign_key:
- assert isinstance(foreign_key, str)
- args.append(ForeignKey(foreign_key))
- kwargs = {
- "primary_key": primary_key,
- "nullable": nullable,
- "index": index,
- "unique": unique,
- }
- sa_default = Undefined
- if field.field_info.default_factory:
- sa_default = field.field_info.default_factory
- elif field.field_info.default is not Undefined:
- sa_default = field.field_info.default
- if sa_default is not Undefined:
- kwargs["default"] = sa_default
- sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
- if sa_column_args is not Undefined:
- args.extend(list(cast(Sequence[Any], sa_column_args)))
- sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
- if sa_column_kwargs is not Undefined:
- kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
- return Column(sa_type, *args, **kwargs) # type: ignore
-
-
class_registry = weakref.WeakValueDictionary() # type: ignore
default_registry = registry()
@@ -639,12 +539,17 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
- __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
+ __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
__name__: ClassVar[str]
metadata: ClassVar[MetaData]
+ __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six
- class Config:
- orm_mode = True
+ if IS_PYDANTIC_V2:
+ model_config = SQLModelConfig(from_attributes=True)
+ else:
+
+ class Config:
+ orm_mode = True
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
new_object = super().__new__(cls)
@@ -653,28 +558,35 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
# Set __fields_set__ here, that would have been set when calling __init__
# in the Pydantic model so that when SQLAlchemy sets attributes that are
# added (e.g. when querying from DB) to the __fields_set__, this already exists
- object.__setattr__(new_object, "__fields_set__", set())
+ set_fields_set(new_object, set())
return new_object
def __init__(__pydantic_self__, **data: Any) -> None:
# Uses something other than `self` the first arg to allow "self" as a
# settable attribute
- values, fields_set, validation_error = validate_model(
- __pydantic_self__.__class__, data
- )
- # Only raise errors if not a SQLModel model
- if (
- not getattr(__pydantic_self__.__config__, "table", False)
- and validation_error
- ):
- raise validation_error
- # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy
- # can handle them
- # object.__setattr__(__pydantic_self__, '__dict__', values)
- for key, value in values.items():
- setattr(__pydantic_self__, key, value)
- object.__setattr__(__pydantic_self__, "__fields_set__", fields_set)
- non_pydantic_keys = data.keys() - values.keys()
+ if IS_PYDANTIC_V2:
+ old_dict = __pydantic_self__.__dict__.copy()
+ super().__init__(**data) # noqa
+ __pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__}
+ non_pydantic_keys = data.keys() - __pydantic_self__.model_fields
+ else:
+ values, fields_set, validation_error = validate_model(
+ __pydantic_self__.__class__, data
+ )
+ # Only raise errors if not a SQLModel model
+ if (
+ not getattr(__pydantic_self__.__config__, "table", False) # noqa
+ and validation_error
+ ):
+ raise validation_error
+ # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy
+ # can handle them
+ # object.__setattr__(__pydantic_self__, '__dict__', values)
+ for key, value in values.items():
+ setattr(__pydantic_self__, key, value)
+ object.__setattr__(__pydantic_self__, "__fields_set__", fields_set)
+ non_pydantic_keys = data.keys() - values.keys()
+
for key in non_pydantic_keys:
if key in __pydantic_self__.__sqlmodel_relationships__:
setattr(__pydantic_self__, key, data[key])
@@ -685,59 +597,13 @@ def __setattr__(self, name: str, value: Any) -> None:
return
else:
# Set in SQLAlchemy, before Pydantic to trigger events and updates
- if getattr(self.__config__, "table", False) and is_instrumented(self, name):
+ if get_config_value(self, "table", False) and is_instrumented(self, name):
set_attribute(self, name, value)
# Set in Pydantic model to trigger possible validation changes, only for
# non relationship values
if name not in self.__sqlmodel_relationships__:
super().__setattr__(name, value)
- @classmethod
- def from_orm(
- cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
- ) -> _TSQLModel:
- # Duplicated from Pydantic
- if not cls.__config__.orm_mode:
- raise ConfigError(
- "You must have the config attribute orm_mode=True to use from_orm"
- )
- obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj)
- # SQLModel, support update dict
- if update is not None:
- obj = {**obj, **update}
- # End SQLModel support dict
- if not getattr(cls.__config__, "table", False):
- # If not table, normal Pydantic code
- m: _TSQLModel = cls.__new__(cls)
- else:
- # If table, create the new instance normally to make SQLAlchemy create
- # the _sa_instance_state attribute
- m = cls()
- values, fields_set, validation_error = validate_model(cls, obj)
- if validation_error:
- raise validation_error
- # Updated to trigger SQLAlchemy internal handling
- if not getattr(cls.__config__, "table", False):
- object.__setattr__(m, "__dict__", values)
- else:
- for key, value in values.items():
- setattr(m, key, value)
- # Continue with standard Pydantic logic
- object.__setattr__(m, "__fields_set__", fields_set)
- m._init_private_attributes()
- return m
-
- @classmethod
- def parse_obj(
- cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
- ) -> _TSQLModel:
- obj = cls._enforce_dict_if_root(obj)
- # SQLModel, support update dict
- if update is not None:
- obj = {**obj, **update}
- # End SQLModel support dict
- return super().parse_obj(obj)
-
def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
# Don't show SQLAlchemy private attributes
return [
@@ -746,78 +612,144 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
if not (isinstance(k, str) and k.startswith("_sa_"))
]
- # From Pydantic, override to enforce validation with dict
- @classmethod
- def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel:
- if isinstance(value, cls):
- return value.copy() if cls.__config__.copy_on_model_validation else value
-
- value = cls._enforce_dict_if_root(value)
- if isinstance(value, dict):
- values, fields_set, validation_error = validate_model(cls, value)
- if validation_error:
- raise validation_error
- model = cls(**value)
- # Reset fields set, this would have been done in Pydantic in __init__
- object.__setattr__(model, "__fields_set__", fields_set)
- return model
- elif cls.__config__.orm_mode:
- return cls.from_orm(value)
- elif cls.__custom_root_type__:
- return cls.parse_obj(value)
- else:
- try:
- value_as_dict = dict(value)
- except (TypeError, ValueError) as e:
- raise DictError() from e
- return cls(**value_as_dict)
-
- # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes
- def _calculate_keys(
- self,
- include: Optional[Mapping[Union[int, str], Any]],
- exclude: Optional[Mapping[Union[int, str], Any]],
- exclude_unset: bool,
- update: Optional[Dict[str, Any]] = None,
- ) -> Optional[AbstractSet[str]]:
- if include is None and exclude is None and not exclude_unset:
- # Original in Pydantic:
- # return None
- # Updated to not return SQLAlchemy attributes
- # Do not include relationships as that would easily lead to infinite
- # recursion, or traversing the whole database
- return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
-
- keys: AbstractSet[str]
- if exclude_unset:
- keys = self.__fields_set__.copy()
- else:
- # Original in Pydantic:
- # keys = self.__dict__.keys()
- # Updated to not return SQLAlchemy attributes
- # Do not include relationships as that would easily lead to infinite
- # recursion, or traversing the whole database
- keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
- if include is not None:
- keys &= include.keys()
-
- if update:
- keys -= update.keys()
-
- if exclude:
- keys -= {k for k, v in exclude.items() if _value_items_is_true(v)}
-
- return keys
-
@declared_attr # type: ignore
def __tablename__(cls) -> str:
return cls.__name__.lower()
+ if IS_PYDANTIC_V2:
+
+ @classmethod
+ def model_validate(
+ cls: Type[_TSQLModel],
+ obj: Any,
+ *,
+ strict: Optional[bool] = None,
+ from_attributes: Optional[bool] = None,
+ context: Optional[Dict[str, Any]] = None,
+ ) -> _TSQLModel:
+ # Somehow model validate doesn't call __init__ so it would remove our init logic
+ validated = super().model_validate(
+ obj, strict=strict, from_attributes=from_attributes, context=context
+ )
+ return cls(**validated.model_dump(exclude_unset=True))
+
+ else:
+
+ @classmethod
+ def from_orm(
+ cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
+ ) -> _TSQLModel:
+ # Duplicated from Pydantic
+ if not cls.__config__.orm_mode: # noqa
+ raise ConfigError(
+ "You must have the config attribute orm_mode=True to use from_orm"
+ )
+ obj = (
+ {ROOT_KEY: obj}
+ if cls.__custom_root_type__ # noqa
+ else cls._decompose_class(obj) # noqa
+ )
+ # SQLModel, support update dict
+ if update is not None:
+ obj = {**obj, **update}
+ # End SQLModel support dict
+ if not getattr(cls.__config__, "table", False): # noqa
+ # If not table, normal Pydantic code
+ m: _TSQLModel = cls.__new__(cls)
+ else:
+ # If table, create the new instance normally to make SQLAlchemy create
+ # the _sa_instance_state attribute
+ m = cls()
+ values, fields_set, validation_error = validate_model(cls, obj)
+ if validation_error:
+ raise validation_error
+ # Updated to trigger SQLAlchemy internal handling
+ if not getattr(cls.__config__, "table", False): # noqa
+ object.__setattr__(m, "__dict__", values)
+ else:
+ for key, value in values.items():
+ setattr(m, key, value)
+ # Continue with standard Pydantic logic
+ object.__setattr__(m, "__fields_set__", fields_set)
+ m._init_private_attributes() # noqa
+ return m
+
+ @classmethod
+ def parse_obj(
+ cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
+ ) -> _TSQLModel:
+ obj = cls._enforce_dict_if_root(obj) # noqa
+ # SQLModel, support update dict
+ if update is not None:
+ obj = {**obj, **update}
+ # End SQLModel support dict
+ return super().parse_obj(obj)
+
+ # From Pydantic, override to enforce validation with dict
+ @classmethod
+ def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel:
+ if isinstance(value, cls):
+ return (
+ value.copy() if cls.__config__.copy_on_model_validation else value # noqa
+ )
-def _is_field_noneable(field: ModelField) -> bool:
- if not field.required:
- # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
- return field.allow_none and (
- field.shape != SHAPE_SINGLETON or not field.sub_fields
- )
- return False
+ value = cls._enforce_dict_if_root(value)
+ if isinstance(value, dict):
+ values, fields_set, validation_error = validate_model(cls, value)
+ if validation_error:
+ raise validation_error
+ model = cls(**value)
+ # Reset fields set, this would have been done in Pydantic in __init__
+ object.__setattr__(model, "__fields_set__", fields_set)
+ return model
+ elif cls.__config__.orm_mode: # noqa
+ return cls.from_orm(value)
+ elif cls.__custom_root_type__: # noqa
+ return cls.parse_obj(value)
+ else:
+ try:
+ value_as_dict = dict(value)
+ except (TypeError, ValueError) as e:
+ raise DictError() from e
+ return cls(**value_as_dict)
+
+ # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes
+ def _calculate_keys(
+ self,
+ include: Optional[Mapping[Union[int, str], Any]],
+ exclude: Optional[Mapping[Union[int, str], Any]],
+ exclude_unset: bool,
+ update: Optional[Dict[str, Any]] = None,
+ ) -> Optional[AbstractSet[str]]:
+ if include is None and exclude is None and not exclude_unset:
+ # Original in Pydantic:
+ # return None
+ # Updated to not return SQLAlchemy attributes
+ # Do not include relationships as that would easily lead to infinite
+ # recursion, or traversing the whole database
+ return (
+ self.__fields__.keys() # noqa
+ ) # | self.__sqlmodel_relationships__.keys()
+
+ keys: AbstractSet[str]
+ if exclude_unset:
+ keys = self.__fields_set__.copy() # noqa
+ else:
+ # Original in Pydantic:
+ # keys = self.__dict__.keys()
+ # Updated to not return SQLAlchemy attributes
+ # Do not include relationships as that would easily lead to infinite
+ # recursion, or traversing the whole database
+ keys = (
+ self.__fields__.keys() # noqa
+ ) # | self.__sqlmodel_relationships__.keys()
+ if include is not None:
+ keys &= include.keys()
+
+ if update:
+ keys -= update.keys()
+
+ if exclude:
+ keys -= {k for k, v in exclude.items() if _value_items_is_true(v)}
+
+ return keys
diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2
index 26d12a0395..b3acb22c95 100644
--- a/sqlmodel/sql/expression.py.jinja2
+++ b/sqlmodel/sql/expression.py.jinja2
@@ -91,4 +91,4 @@ def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type:
def col(column_expression: Any) -> ColumnClause: # type: ignore
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
- return column_expression
+ return column_expression
\ No newline at end of file
diff --git a/tests/conftest.py b/tests/conftest.py
index 2b8e5fc29e..020b33a566 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,6 +6,7 @@
import pytest
from pydantic import BaseModel
from sqlmodel import SQLModel
+from sqlmodel.compat import IS_PYDANTIC_V2
from sqlmodel.main import default_registry
top_level_path = Path(__file__).resolve().parent.parent
@@ -67,3 +68,7 @@ def new_print(*args):
calls.append(data)
return new_print
+
+
+needs_pydanticv2 = pytest.mark.skipif(not IS_PYDANTIC_V2, reason="requires Pydantic v2")
+needs_pydanticv1 = pytest.mark.skipif(IS_PYDANTIC_V2, reason="requires Pydantic v1")
diff --git a/tests/test_enums.py b/tests/test_enums.py
index 194bdefea1..07a04c686e 100644
--- a/tests/test_enums.py
+++ b/tests/test_enums.py
@@ -5,6 +5,8 @@
from sqlalchemy.sql.type_api import TypeEngine
from sqlmodel import Field, SQLModel
+from .conftest import needs_pydanticv1, needs_pydanticv2
+
"""
Tests related to Enums
@@ -72,7 +74,8 @@ def test_sqlite_ddl_sql(capsys):
assert "CREATE TYPE" not in captured.out
-def test_json_schema_flat_model():
+@needs_pydanticv1
+def test_json_schema_flat_model_pydantic_v1():
assert FlatModel.schema() == {
"title": "FlatModel",
"type": "object",
@@ -92,7 +95,8 @@ def test_json_schema_flat_model():
}
-def test_json_schema_inherit_model():
+@needs_pydanticv1
+def test_json_schema_inherit_model_pydantic_v1():
assert InheritModel.schema() == {
"title": "InheritModel",
"type": "object",
@@ -110,3 +114,42 @@ def test_json_schema_inherit_model():
}
},
}
+
+
+@needs_pydanticv2
+def test_json_schema_flat_model_pydantic_v2():
+ assert FlatModel.model_json_schema() == {
+ "title": "FlatModel",
+ "type": "object",
+ "properties": {
+ "id": {"default": None, "format": "uuid", "title": "Id", "type": "string"},
+ "enum_field": {"allOf": [{"$ref": "#/$defs/MyEnum1"}], "default": None},
+ },
+ "$defs": {
+ "MyEnum1": {
+ "title": "MyEnum1",
+ "enum": ["A", "B"],
+ "type": "string",
+ }
+ },
+ }
+
+
+@needs_pydanticv2
+def test_json_schema_inherit_model_pydantic_v2():
+ assert InheritModel.model_json_schema() == {
+ "title": "InheritModel",
+ "type": "object",
+ "properties": {
+ "id": {"title": "Id", "type": "string", "format": "uuid"},
+ "enum_field": {"$ref": "#/$defs/MyEnum2"},
+ },
+ "required": ["id", "enum_field"],
+ "$defs": {
+ "MyEnum2": {
+ "title": "MyEnum2",
+ "enum": ["C", "D"],
+ "type": "string",
+ }
+ },
+ }
diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py
index 14d560628b..e54e8163b3 100644
--- a/tests/test_instance_no_args.py
+++ b/tests/test_instance_no_args.py
@@ -1,11 +1,16 @@
from typing import Optional
+import pytest
+from pydantic import ValidationError
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session
from sqlmodel import Field, SQLModel
+from .conftest import needs_pydanticv1, needs_pydanticv2
-def test_allow_instantiation_without_arguments(clear_sqlmodel):
+
+@needs_pydanticv1
+def test_allow_instantiation_without_arguments_pydantic_v1(clear_sqlmodel):
class Item(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
@@ -25,3 +30,33 @@ class Config:
assert len(result) == 1
assert isinstance(item.id, int)
SQLModel.metadata.clear()
+
+
+def test_not_allow_instantiation_without_arguments_if_not_table():
+ class Item(SQLModel):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ name: str
+ description: Optional[str] = None
+
+ with pytest.raises(ValidationError):
+ Item()
+
+
+@needs_pydanticv2
+def test_allow_instantiation_without_arguments_pydantic_v2(clear_sqlmodel):
+ class Item(SQLModel, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ name: str
+ description: Optional[str] = None
+
+ engine = create_engine("sqlite:///:memory:")
+ SQLModel.metadata.create_all(engine)
+ with Session(engine) as db:
+ item = Item()
+ item.name = "Rick"
+ db.add(item)
+ db.commit()
+ result = db.execute(select(Item)).scalars().all()
+ assert len(result) == 1
+ assert isinstance(item.id, int)
+ SQLModel.metadata.clear()
diff --git a/tests/test_missing_type.py b/tests/test_missing_type.py
index 2185fa43e9..dc31f053ec 100644
--- a/tests/test_missing_type.py
+++ b/tests/test_missing_type.py
@@ -1,11 +1,12 @@
from typing import Optional
import pytest
+from pydantic import BaseModel
from sqlmodel import Field, SQLModel
def test_missing_sql_type():
- class CustomType:
+ class CustomType(BaseModel):
@classmethod
def __get_validators__(cls):
yield cls.validate
diff --git a/tests/test_nullable.py b/tests/test_nullable.py
index 1c8b37b218..a40bb5b5f0 100644
--- a/tests/test_nullable.py
+++ b/tests/test_nullable.py
@@ -58,7 +58,7 @@ class Hero(SQLModel, table=True):
][0]
assert "primary_key INTEGER NOT NULL," in create_table_log
assert "required_value VARCHAR NOT NULL," in create_table_log
- assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log
+ assert "optional_default_ellipsis VARCHAR," in create_table_log
assert "optional_default_none VARCHAR," in create_table_log
assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log
assert "optional_nullable VARCHAR," in create_table_log
diff --git a/tests/test_validation.py b/tests/test_validation.py
index ad60fcb945..3265922070 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -1,12 +1,14 @@
from typing import Optional
import pytest
-from pydantic import validator
from pydantic.error_wrappers import ValidationError
from sqlmodel import SQLModel
+from .conftest import needs_pydanticv1, needs_pydanticv2
-def test_validation(clear_sqlmodel):
+
+@needs_pydanticv1
+def test_validation_pydantic_v1(clear_sqlmodel):
"""Test validation of implicit and explicit None values.
# For consistency with pydantic, validators are not to be called on
@@ -16,6 +18,7 @@ def test_validation(clear_sqlmodel):
https://github.com/samuelcolvin/pydantic/issues/1223
"""
+ from pydantic import validator
class Hero(SQLModel):
name: Optional[str] = None
@@ -31,3 +34,32 @@ def reject_none(cls, v):
with pytest.raises(ValidationError):
Hero.validate({"name": None, "age": 25})
+
+
+@needs_pydanticv2
+def test_validation_pydantic_v2(clear_sqlmodel):
+ """Test validation of implicit and explicit None values.
+
+ # For consistency with pydantic, validators are not to be called on
+ # arguments that are not explicitly provided.
+
+ https://github.com/tiangolo/sqlmodel/issues/230
+ https://github.com/samuelcolvin/pydantic/issues/1223
+
+ """
+ from pydantic import field_validator
+
+ class Hero(SQLModel):
+ name: Optional[str] = None
+ secret_name: Optional[str] = None
+ age: Optional[int] = None
+
+ @field_validator("name", "secret_name", "age")
+ def reject_none(cls, v):
+ assert v is not None
+ return v
+
+ Hero.model_validate({"age": 25})
+
+ with pytest.raises(ValidationError):
+ Hero.model_validate({"name": None, "age": 25})