diff --git a/README.md b/README.md
index 490cbe2..a36c6e2 100644
--- a/README.md
+++ b/README.md
@@ -9,12 +9,6 @@
FastAPI OAuth2 is a middleware-based social authentication mechanism supporting several auth providers. It depends on
the [social-core](https://github.com/python-social-auth/social-core) authentication backends.
-## Features to be implemented
-
-- Use multiple OAuth2 providers at the same time
- * There need to be provided a way to configure the OAuth2 for multiple providers
-- Customizable OAuth2 routes
-
## Installation
```shell
diff --git a/examples/demonstration/.env b/examples/demonstration/.env
index a1c0106..25f028b 100644
--- a/examples/demonstration/.env
+++ b/examples/demonstration/.env
@@ -1,5 +1,10 @@
-OAUTH2_CLIENT_ID=eccd08d6736b7999a32a
-OAUTH2_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31
+# These id and secret are generated especially for testing purposes,
+# if you have your own, please use them, otherwise you can use these.
+OAUTH2_GITHUB_CLIENT_ID=eccd08d6736b7999a32a
+OAUTH2_GITHUB_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31
+
+OAUTH2_GOOGLE_CLIENT_ID=105851609656-uueuan570963mnnf4288nv40eieh9f5l.apps.googleusercontent.com
+OAUTH2_GOOGLE_CLIENT_SECRET=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
JWT_SECRET=secret
JWT_ALGORITHM=HS256
diff --git a/examples/demonstration/config.py b/examples/demonstration/config.py
index 935c2b1..be64b0f 100644
--- a/examples/demonstration/config.py
+++ b/examples/demonstration/config.py
@@ -2,6 +2,7 @@
from dotenv import load_dotenv
from social_core.backends.github import GithubOAuth2
+from social_core.backends.google import GoogleOAuth2
from fastapi_oauth2.claims import Claims
from fastapi_oauth2.client import OAuth2Client
@@ -17,14 +18,22 @@
clients=[
OAuth2Client(
backend=GithubOAuth2,
- client_id=os.getenv("OAUTH2_CLIENT_ID"),
- client_secret=os.getenv("OAUTH2_CLIENT_SECRET"),
- # redirect_uri="http://127.0.0.1:8000/",
+ client_id=os.getenv("OAUTH2_GITHUB_CLIENT_ID"),
+ client_secret=os.getenv("OAUTH2_GITHUB_CLIENT_SECRET"),
scope=["user:email"],
claims=Claims(
picture="avatar_url",
identity=lambda user: "%s:%s" % (user.get("provider"), user.get("id")),
),
),
+ OAuth2Client(
+ backend=GoogleOAuth2,
+ client_id=os.getenv("OAUTH2_GOOGLE_CLIENT_ID"),
+ client_secret=os.getenv("OAUTH2_GOOGLE_CLIENT_SECRET"),
+ scope=["openid", "profile", "email"],
+ claims=Claims(
+ identity=lambda user: "%s:%s" % (user.get("provider"), user.get("sub")),
+ ),
+ ),
]
)
diff --git a/examples/demonstration/main.py b/examples/demonstration/main.py
index e657bf1..4b78238 100644
--- a/examples/demonstration/main.py
+++ b/examples/demonstration/main.py
@@ -1,5 +1,6 @@
from fastapi import APIRouter
from fastapi import FastAPI
+from fastapi.staticfiles import StaticFiles
from sqlalchemy.orm import Session
from config import oauth2_config
@@ -24,16 +25,18 @@ async def on_auth(auth: Auth, user: User):
db: Session = next(get_db())
query = db.query(UserModel)
if user.identity and not query.filter_by(identity=user.identity).first():
+ # create a local user by OAuth2 user's data if it does not exist yet
UserModel(**{
- "identity": user.get("identity"),
- "username": user.get("username"),
- "image": user.get("image"),
- "email": user.get("email"),
- "name": user.get("name"),
+ "identity": user.identity, # User property
+ "username": user.get("username"), # custom attribute
+ "name": user.display_name, # User property
+ "image": user.picture, # User property
+ "email": user.email, # User property
}).save(db)
app = FastAPI()
app.include_router(app_router)
app.include_router(oauth2_router)
+app.mount("/static", StaticFiles(directory="static"), name="static")
app.add_middleware(OAuth2Middleware, config=oauth2_config, callback=on_auth)
diff --git a/examples/demonstration/static/github.svg b/examples/demonstration/static/github.svg
new file mode 100644
index 0000000..75a94ed
--- /dev/null
+++ b/examples/demonstration/static/github.svg
@@ -0,0 +1,5 @@
+
+
\ No newline at end of file
diff --git a/examples/demonstration/static/google-oauth2.svg b/examples/demonstration/static/google-oauth2.svg
new file mode 100644
index 0000000..ac18388
--- /dev/null
+++ b/examples/demonstration/static/google-oauth2.svg
@@ -0,0 +1,6 @@
+
+
\ No newline at end of file
diff --git a/examples/demonstration/templates/index.html b/examples/demonstration/templates/index.html
index 9a8b81d..caea8e5 100644
--- a/examples/demonstration/templates/index.html
+++ b/examples/demonstration/templates/index.html
@@ -21,11 +21,15 @@
Simulate Login
-
-
-
+ {% for provider in request.auth.clients %}
+
+
+
+ {% endfor %}
{% endif %}
@@ -33,6 +37,14 @@
style="display: flex; flex-direction: column; align-items: center; justify-content: center; height: calc(100vh - 70px);">
{% if request.user.is_authenticated %}
This is what your JWT contains currently
{{ json.dumps(request.user, indent=4) }}{% else %} diff --git a/setup.cfg b/setup.cfg index ed46db8..81b6a1d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,7 +27,7 @@ license_files = LICENSE platforms = unix, linux, osx, win32 classifiers = Operating System :: OS Independent - Development Status :: 2 - Pre-Alpha + Development Status :: 3 - Alpha Framework :: FastAPI Programming Language :: Python Programming Language :: Python :: 3 diff --git a/src/fastapi_oauth2/__init__.py b/src/fastapi_oauth2/__init__.py index a390618..5186ae4 100644 --- a/src/fastapi_oauth2/__init__.py +++ b/src/fastapi_oauth2/__init__.py @@ -1 +1 @@ -__version__ = "1.0.0-alpha.1" +__version__ = "1.0.0-alpha.2" diff --git a/src/fastapi_oauth2/core.py b/src/fastapi_oauth2/core.py index a9e7291..3a4ea18 100644 --- a/src/fastapi_oauth2/core.py +++ b/src/fastapi_oauth2/core.py @@ -10,6 +10,7 @@ import httpx from oauthlib.oauth2 import WebApplicationClient +from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error from social_core.backends.oauth import BaseOAuth2 from social_core.strategy import BaseStrategy from starlette.exceptions import HTTPException @@ -46,9 +47,10 @@ class OAuth2Core: client_id: str = None client_secret: str = None - callback_url: Optional[str] = None scope: Optional[List[str]] = None claims: Optional[Claims] = None + provider: str = None + redirect_uri: str = None backend: BaseOAuth2 = None _oauth_client: Optional[WebApplicationClient] = None @@ -108,9 +110,12 @@ async def token_redirect(self, request: Request) -> RedirectResponse: auth = httpx.BasicAuth(self.client_id, self.client_secret) async with httpx.AsyncClient() as session: response = await session.post(token_url, headers=headers, content=content, auth=auth) - token = self.oauth_client.parse_request_body_response(json.dumps(response.json())) - token_data = self.standardize(self.backend.user_data(token.get("access_token"))) - access_token = request.auth.jwt_create(token_data) + try: + token = self.oauth_client.parse_request_body_response(json.dumps(response.json())) + token_data = self.standardize(self.backend.user_data(token.get("access_token"))) + access_token = request.auth.jwt_create(token_data) + except (CustomOAuth2Error, Exception) as e: + raise OAuth2LoginError(400, str(e)) response = RedirectResponse(self.redirect_uri or request.base_url) response.set_cookie( diff --git a/src/fastapi_oauth2/middleware.py b/src/fastapi_oauth2/middleware.py index c921f7b..5dd5eb1 100644 --- a/src/fastapi_oauth2/middleware.py +++ b/src/fastapi_oauth2/middleware.py @@ -6,7 +6,6 @@ from typing import Dict from typing import List from typing import Optional -from typing import Sequence from typing import Tuple from typing import Union @@ -39,16 +38,15 @@ class Auth(AuthCredentials): scopes: List[str] clients: Dict[str, OAuth2Core] = {} - provider: str - default_provider: str = "local" + _provider: OAuth2Core = None - def __init__( - self, - scopes: Optional[Sequence[str]] = None, - provider: str = default_provider, - ) -> None: - super().__init__(scopes) - self.provider = provider + @property + def provider(self) -> Union[OAuth2Core, None]: + return self._provider + + @provider.setter + def provider(self, identifier) -> None: + self._provider = self.clients.get(identifier) @classmethod def set_http(cls, http: bool) -> None: @@ -145,18 +143,16 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]: return Auth(), User() user = User(Auth.jwt_decode(param)) - user.update(provider=user.get("provider", Auth.default_provider)) - auth = Auth(user.pop("scope", []), user.get("provider")) - client = Auth.clients.get(auth.provider) - claims = client.claims if client else Claims() - user = user.use_claims(claims) + auth = Auth(user.pop("scope", [])) + auth.provider = user.get("provider") + claims = auth.provider.claims if auth.provider else {} # Call the callback function on authentication if callable(self.callback): - coroutine = self.callback(auth, user) + coroutine = self.callback(auth, user.use_claims(claims)) if issubclass(type(coroutine), Awaitable): await coroutine - return auth, user + return auth, user.use_claims(claims) class OAuth2Middleware: diff --git a/src/fastapi_oauth2/security.py b/src/fastapi_oauth2/security.py index 0f5d3b3..fddc067 100644 --- a/src/fastapi_oauth2/security.py +++ b/src/fastapi_oauth2/security.py @@ -1,8 +1,4 @@ -from typing import Any -from typing import Callable -from typing import Dict from typing import Optional -from typing import Tuple from typing import Type from fastapi.security import OAuth2 as FastAPIOAuth2 @@ -12,32 +8,29 @@ from starlette.requests import Request -def use_cookies(cls: Type[FastAPIOAuth2]) -> Callable[[Tuple[Any], Dict[str, Any]], FastAPIOAuth2]: - """OAuth2 classes wrapped with this decorator will use cookies for the Authorization header.""" +class OAuth2Cookie(type): + """OAuth2 classes using this metaclass will use cookies for the Authorization header.""" + + def __new__(metacls, name, bases, attrs) -> Type: + instance = super().__new__(metacls, name, bases, attrs) - def _use_cookies(*args, **kwargs) -> FastAPIOAuth2: async def __call__(self: FastAPIOAuth2, request: Request) -> Optional[str]: authorization = request.headers.get("Authorization", request.cookies.get("Authorization")) if authorization: request._headers = Headers({**request.headers, "Authorization": authorization}) - return await super(cls, self).__call__(request) - - cls.__call__ = __call__ - return cls(*args, **kwargs) + return await instance.__base__.__call__(self, request) - return _use_cookies + instance.__call__ = __call__ + return instance -@use_cookies -class OAuth2(FastAPIOAuth2): +class OAuth2(FastAPIOAuth2, metaclass=OAuth2Cookie): """Wrapper class of the `fastapi.security.OAuth2` class.""" -@use_cookies -class OAuth2PasswordBearer(FastAPIPasswordBearer): +class OAuth2PasswordBearer(FastAPIPasswordBearer, metaclass=OAuth2Cookie): """Wrapper class of the `fastapi.security.OAuth2PasswordBearer` class.""" -@use_cookies -class OAuth2AuthorizationCodeBearer(FastAPICodeBearer): +class OAuth2AuthorizationCodeBearer(FastAPICodeBearer, metaclass=OAuth2Cookie): """Wrapper class of the `fastapi.security.OAuth2AuthorizationCodeBearer` class.""" diff --git a/tests/conftest.py b/tests/conftest.py index aedb52a..b96e6c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,18 @@ import pytest import social_core.backends as backends +from fastapi import APIRouter +from fastapi import Depends +from fastapi import FastAPI +from fastapi import Request +from social_core.backends.github import GithubOAuth2 from social_core.backends.oauth import BaseOAuth2 +from starlette.responses import Response + +from fastapi_oauth2.client import OAuth2Client +from fastapi_oauth2.middleware import OAuth2Middleware +from fastapi_oauth2.router import router as oauth2_router +from fastapi_oauth2.security import OAuth2 package_path = backends.__path__[0] @@ -24,3 +35,52 @@ def backends(): except ImportError: continue return backend_instances + + +@pytest.fixture +def get_app(): + def fixture_wrapper(authentication: OAuth2 = None): + if not authentication: + authentication = OAuth2() + + oauth2 = authentication + application = FastAPI() + app_router = APIRouter() + + @app_router.get("/user") + def user(request: Request, _: str = Depends(oauth2)): + return request.user + + @app_router.get("/auth") + def auth(request: Request): + access_token = request.auth.jwt_create({ + "name": "test", + "sub": "test", + "id": "test", + }) + response = Response() + response.set_cookie( + "Authorization", + value=f"Bearer {access_token}", + max_age=request.auth.expires, + expires=request.auth.expires, + httponly=request.auth.http, + ) + return response + + application.include_router(app_router) + application.include_router(oauth2_router) + application.add_middleware(OAuth2Middleware, config={ + "allow_http": True, + "clients": [ + OAuth2Client( + backend=GithubOAuth2, + client_id="test_id", + client_secret="test_secret", + ), + ], + }) + + return application + + return fixture_wrapper diff --git a/tests/test_backends.py b/tests/test_backends.py new file mode 100644 index 0000000..47a91d6 --- /dev/null +++ b/tests/test_backends.py @@ -0,0 +1,17 @@ +import pytest + +from fastapi_oauth2.client import OAuth2Client +from fastapi_oauth2.core import OAuth2Core + + +@pytest.mark.anyio +async def test_core_init_with_all_backends(backends): + for backend in backends: + try: + OAuth2Core(OAuth2Client( + backend=backend, + client_id="test_client_id", + client_secret="test_client_secret", + )) + except (NotImplementedError, Exception): + assert False diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 0000000..e33c6b7 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,28 @@ +import pytest +from httpx import AsyncClient + + +@pytest.mark.anyio +async def test_middleware_on_authentication(get_app): + async with AsyncClient(app=get_app(), base_url="http://test") as client: + response = await client.get("/user") + assert response.status_code == 403 # Forbidden + + await client.get("/auth") # Simulate login + + response = await client.get("/user") + assert response.status_code == 200 # OK + + +@pytest.mark.anyio +async def test_middleware_on_logout(get_app): + async with AsyncClient(app=get_app(), base_url="http://test") as client: + await client.get("/auth") # Simulate login + + response = await client.get("/user") + assert response.status_code == 200 # OK + + await client.get("/oauth2/logout") # Perform logout + + response = await client.get("/user") + assert response.status_code == 403 # Forbidden diff --git a/tests/test_oauth2_middleware.py b/tests/test_oauth2_middleware.py deleted file mode 100644 index 27d0752..0000000 --- a/tests/test_oauth2_middleware.py +++ /dev/null @@ -1,87 +0,0 @@ -import pytest -from fastapi import APIRouter -from fastapi import Depends -from fastapi import FastAPI -from fastapi import Request -from httpx import AsyncClient -from social_core.backends.github import GithubOAuth2 -from starlette.responses import Response - -from fastapi_oauth2.client import OAuth2Client -from fastapi_oauth2.core import OAuth2Core -from fastapi_oauth2.middleware import OAuth2Middleware -from fastapi_oauth2.router import router as oauth2_router -from fastapi_oauth2.security import OAuth2 - -app = FastAPI() -oauth2 = OAuth2() -app_router = APIRouter() - - -@app_router.get("/user") -def user(request: Request, _: str = Depends(oauth2)): - return request.user - - -@app_router.get("/auth") -def auth(request: Request): - access_token = request.auth.jwt_create({ - "name": "test", - "sub": "test", - "id": "test", - }) - response = Response() - response.set_cookie( - "Authorization", - value=f"Bearer {access_token}", - max_age=request.auth.expires, - expires=request.auth.expires, - httponly=request.auth.http, - ) - return response - - -app.include_router(app_router) -app.include_router(oauth2_router) -app.add_middleware(OAuth2Middleware, config={ - "allow_http": True, - "clients": [ - OAuth2Client( - backend=GithubOAuth2, - client_id="test_id", - client_secret="test_secret", - ), - ], -}) - - -@pytest.mark.anyio -async def test_auth_redirect(): - async with AsyncClient(app=app, base_url="http://test") as client: - response = await client.get("/oauth2/github/auth") - assert response.status_code == 303 # Redirect - - -@pytest.mark.anyio -async def test_authenticated_request(): - async with AsyncClient(app=app, base_url="http://test") as client: - response = await client.get("/user") - assert response.status_code == 403 # Forbidden - - await client.get("/auth") # Simulate login - - response = await client.get("/user") - assert response.status_code == 200 # OK - - -@pytest.mark.anyio -async def test_core_init(backends): - for backend in backends: - try: - OAuth2Core(OAuth2Client( - backend=backend, - client_id="test_client_id", - client_secret="test_client_secret", - )) - except (NotImplementedError, Exception): - assert False diff --git a/tests/test_router.py b/tests/test_router.py new file mode 100644 index 0000000..084f459 --- /dev/null +++ b/tests/test_router.py @@ -0,0 +1,26 @@ +import pytest +from httpx import AsyncClient + + +@pytest.mark.anyio +async def test_auth_redirect(get_app): + async with AsyncClient(app=get_app(), base_url="http://test") as client: + response = await client.get("/oauth2/github/auth") + assert response.status_code == 303 # Redirect + + +@pytest.mark.anyio +async def test_token_redirect(get_app): + async with AsyncClient(app=get_app(), base_url="http://test") as client: + response = await client.get("/oauth2/github/token") + assert response.status_code == 400 # Bad Request + + response = await client.get("/oauth2/github/token?state=test&code=test") + assert response.status_code == 400 # Bad Request + + +@pytest.mark.anyio +async def test_logout_redirect(get_app): + async with AsyncClient(app=get_app(), base_url="http://test") as client: + response = await client.get("/oauth2/logout") + assert response.status_code == 307 # Redirect diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..9c8fa1f --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,29 @@ +import pytest + +from fastapi_oauth2.security import OAuth2 +from fastapi_oauth2.security import OAuth2AuthorizationCodeBearer +from fastapi_oauth2.security import OAuth2PasswordBearer + + +@pytest.mark.anyio +async def test_security_oauth2(get_app): + try: + get_app(OAuth2()) + except (TypeError, Exception): + assert False + + +@pytest.mark.anyio +async def test_security_oauth2_password_bearer(get_app): + try: + get_app(OAuth2PasswordBearer(tokenUrl="/test")) + except (TypeError, Exception): + assert False + + +@pytest.mark.anyio +async def test_security_oauth2_authentication_code_bearer(get_app): + try: + get_app(OAuth2AuthorizationCodeBearer(authorizationUrl="/test", tokenUrl="/test")) + except (TypeError, Exception): + assert False