diff --git a/README.md b/README.md index 490cbe2..a36c6e2 100644 --- a/README.md +++ b/README.md @@ -9,12 +9,6 @@ FastAPI OAuth2 is a middleware-based social authentication mechanism supporting several auth providers. It depends on the [social-core](https://github.com/python-social-auth/social-core) authentication backends. -## Features to be implemented - -- Use multiple OAuth2 providers at the same time - * There need to be provided a way to configure the OAuth2 for multiple providers -- Customizable OAuth2 routes - ## Installation ```shell diff --git a/examples/demonstration/.env b/examples/demonstration/.env index a1c0106..25f028b 100644 --- a/examples/demonstration/.env +++ b/examples/demonstration/.env @@ -1,5 +1,10 @@ -OAUTH2_CLIENT_ID=eccd08d6736b7999a32a -OAUTH2_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31 +# These id and secret are generated especially for testing purposes, +# if you have your own, please use them, otherwise you can use these. +OAUTH2_GITHUB_CLIENT_ID=eccd08d6736b7999a32a +OAUTH2_GITHUB_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31 + +OAUTH2_GOOGLE_CLIENT_ID=105851609656-uueuan570963mnnf4288nv40eieh9f5l.apps.googleusercontent.com +OAUTH2_GOOGLE_CLIENT_SECRET=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 JWT_SECRET=secret JWT_ALGORITHM=HS256 diff --git a/examples/demonstration/config.py b/examples/demonstration/config.py index 935c2b1..be64b0f 100644 --- a/examples/demonstration/config.py +++ b/examples/demonstration/config.py @@ -2,6 +2,7 @@ from dotenv import load_dotenv from social_core.backends.github import GithubOAuth2 +from social_core.backends.google import GoogleOAuth2 from fastapi_oauth2.claims import Claims from fastapi_oauth2.client import OAuth2Client @@ -17,14 +18,22 @@ clients=[ OAuth2Client( backend=GithubOAuth2, - client_id=os.getenv("OAUTH2_CLIENT_ID"), - client_secret=os.getenv("OAUTH2_CLIENT_SECRET"), - # redirect_uri="http://127.0.0.1:8000/", + client_id=os.getenv("OAUTH2_GITHUB_CLIENT_ID"), + client_secret=os.getenv("OAUTH2_GITHUB_CLIENT_SECRET"), scope=["user:email"], claims=Claims( picture="avatar_url", identity=lambda user: "%s:%s" % (user.get("provider"), user.get("id")), ), ), + OAuth2Client( + backend=GoogleOAuth2, + client_id=os.getenv("OAUTH2_GOOGLE_CLIENT_ID"), + client_secret=os.getenv("OAUTH2_GOOGLE_CLIENT_SECRET"), + scope=["openid", "profile", "email"], + claims=Claims( + identity=lambda user: "%s:%s" % (user.get("provider"), user.get("sub")), + ), + ), ] ) diff --git a/examples/demonstration/main.py b/examples/demonstration/main.py index e657bf1..4b78238 100644 --- a/examples/demonstration/main.py +++ b/examples/demonstration/main.py @@ -1,5 +1,6 @@ from fastapi import APIRouter from fastapi import FastAPI +from fastapi.staticfiles import StaticFiles from sqlalchemy.orm import Session from config import oauth2_config @@ -24,16 +25,18 @@ async def on_auth(auth: Auth, user: User): db: Session = next(get_db()) query = db.query(UserModel) if user.identity and not query.filter_by(identity=user.identity).first(): + # create a local user by OAuth2 user's data if it does not exist yet UserModel(**{ - "identity": user.get("identity"), - "username": user.get("username"), - "image": user.get("image"), - "email": user.get("email"), - "name": user.get("name"), + "identity": user.identity, # User property + "username": user.get("username"), # custom attribute + "name": user.display_name, # User property + "image": user.picture, # User property + "email": user.email, # User property }).save(db) app = FastAPI() app.include_router(app_router) app.include_router(oauth2_router) +app.mount("/static", StaticFiles(directory="static"), name="static") app.add_middleware(OAuth2Middleware, config=oauth2_config, callback=on_auth) diff --git a/examples/demonstration/static/github.svg b/examples/demonstration/static/github.svg new file mode 100644 index 0000000..75a94ed --- /dev/null +++ b/examples/demonstration/static/github.svg @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/examples/demonstration/static/google-oauth2.svg b/examples/demonstration/static/google-oauth2.svg new file mode 100644 index 0000000..ac18388 --- /dev/null +++ b/examples/demonstration/static/google-oauth2.svg @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/examples/demonstration/templates/index.html b/examples/demonstration/templates/index.html index 9a8b81d..caea8e5 100644 --- a/examples/demonstration/templates/index.html +++ b/examples/demonstration/templates/index.html @@ -21,11 +21,15 @@ Simulate Login - - - - - + {% for provider in request.auth.clients %} + + {{ provider }} icon + + {% endfor %} {% endif %} @@ -33,6 +37,14 @@ style="display: flex; flex-direction: column; align-items: center; justify-content: center; height: calc(100vh - 70px);"> {% if request.user.is_authenticated %}

Hi, {{ request.user.display_name }}

+

+ You're signed in using + {% if request.auth.provider %} + external {{ request.auth.provider.provider }} OAuth2 provider. + {% else %} + local authentication system. + {% endif %} +

This is what your JWT contains currently

{{ json.dumps(request.user, indent=4) }}
{% else %} diff --git a/setup.cfg b/setup.cfg index ed46db8..81b6a1d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,7 +27,7 @@ license_files = LICENSE platforms = unix, linux, osx, win32 classifiers = Operating System :: OS Independent - Development Status :: 2 - Pre-Alpha + Development Status :: 3 - Alpha Framework :: FastAPI Programming Language :: Python Programming Language :: Python :: 3 diff --git a/src/fastapi_oauth2/__init__.py b/src/fastapi_oauth2/__init__.py index a390618..5186ae4 100644 --- a/src/fastapi_oauth2/__init__.py +++ b/src/fastapi_oauth2/__init__.py @@ -1 +1 @@ -__version__ = "1.0.0-alpha.1" +__version__ = "1.0.0-alpha.2" diff --git a/src/fastapi_oauth2/core.py b/src/fastapi_oauth2/core.py index a9e7291..3a4ea18 100644 --- a/src/fastapi_oauth2/core.py +++ b/src/fastapi_oauth2/core.py @@ -10,6 +10,7 @@ import httpx from oauthlib.oauth2 import WebApplicationClient +from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error from social_core.backends.oauth import BaseOAuth2 from social_core.strategy import BaseStrategy from starlette.exceptions import HTTPException @@ -46,9 +47,10 @@ class OAuth2Core: client_id: str = None client_secret: str = None - callback_url: Optional[str] = None scope: Optional[List[str]] = None claims: Optional[Claims] = None + provider: str = None + redirect_uri: str = None backend: BaseOAuth2 = None _oauth_client: Optional[WebApplicationClient] = None @@ -108,9 +110,12 @@ async def token_redirect(self, request: Request) -> RedirectResponse: auth = httpx.BasicAuth(self.client_id, self.client_secret) async with httpx.AsyncClient() as session: response = await session.post(token_url, headers=headers, content=content, auth=auth) - token = self.oauth_client.parse_request_body_response(json.dumps(response.json())) - token_data = self.standardize(self.backend.user_data(token.get("access_token"))) - access_token = request.auth.jwt_create(token_data) + try: + token = self.oauth_client.parse_request_body_response(json.dumps(response.json())) + token_data = self.standardize(self.backend.user_data(token.get("access_token"))) + access_token = request.auth.jwt_create(token_data) + except (CustomOAuth2Error, Exception) as e: + raise OAuth2LoginError(400, str(e)) response = RedirectResponse(self.redirect_uri or request.base_url) response.set_cookie( diff --git a/src/fastapi_oauth2/middleware.py b/src/fastapi_oauth2/middleware.py index c921f7b..5dd5eb1 100644 --- a/src/fastapi_oauth2/middleware.py +++ b/src/fastapi_oauth2/middleware.py @@ -6,7 +6,6 @@ from typing import Dict from typing import List from typing import Optional -from typing import Sequence from typing import Tuple from typing import Union @@ -39,16 +38,15 @@ class Auth(AuthCredentials): scopes: List[str] clients: Dict[str, OAuth2Core] = {} - provider: str - default_provider: str = "local" + _provider: OAuth2Core = None - def __init__( - self, - scopes: Optional[Sequence[str]] = None, - provider: str = default_provider, - ) -> None: - super().__init__(scopes) - self.provider = provider + @property + def provider(self) -> Union[OAuth2Core, None]: + return self._provider + + @provider.setter + def provider(self, identifier) -> None: + self._provider = self.clients.get(identifier) @classmethod def set_http(cls, http: bool) -> None: @@ -145,18 +143,16 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]: return Auth(), User() user = User(Auth.jwt_decode(param)) - user.update(provider=user.get("provider", Auth.default_provider)) - auth = Auth(user.pop("scope", []), user.get("provider")) - client = Auth.clients.get(auth.provider) - claims = client.claims if client else Claims() - user = user.use_claims(claims) + auth = Auth(user.pop("scope", [])) + auth.provider = user.get("provider") + claims = auth.provider.claims if auth.provider else {} # Call the callback function on authentication if callable(self.callback): - coroutine = self.callback(auth, user) + coroutine = self.callback(auth, user.use_claims(claims)) if issubclass(type(coroutine), Awaitable): await coroutine - return auth, user + return auth, user.use_claims(claims) class OAuth2Middleware: diff --git a/src/fastapi_oauth2/security.py b/src/fastapi_oauth2/security.py index 0f5d3b3..fddc067 100644 --- a/src/fastapi_oauth2/security.py +++ b/src/fastapi_oauth2/security.py @@ -1,8 +1,4 @@ -from typing import Any -from typing import Callable -from typing import Dict from typing import Optional -from typing import Tuple from typing import Type from fastapi.security import OAuth2 as FastAPIOAuth2 @@ -12,32 +8,29 @@ from starlette.requests import Request -def use_cookies(cls: Type[FastAPIOAuth2]) -> Callable[[Tuple[Any], Dict[str, Any]], FastAPIOAuth2]: - """OAuth2 classes wrapped with this decorator will use cookies for the Authorization header.""" +class OAuth2Cookie(type): + """OAuth2 classes using this metaclass will use cookies for the Authorization header.""" + + def __new__(metacls, name, bases, attrs) -> Type: + instance = super().__new__(metacls, name, bases, attrs) - def _use_cookies(*args, **kwargs) -> FastAPIOAuth2: async def __call__(self: FastAPIOAuth2, request: Request) -> Optional[str]: authorization = request.headers.get("Authorization", request.cookies.get("Authorization")) if authorization: request._headers = Headers({**request.headers, "Authorization": authorization}) - return await super(cls, self).__call__(request) - - cls.__call__ = __call__ - return cls(*args, **kwargs) + return await instance.__base__.__call__(self, request) - return _use_cookies + instance.__call__ = __call__ + return instance -@use_cookies -class OAuth2(FastAPIOAuth2): +class OAuth2(FastAPIOAuth2, metaclass=OAuth2Cookie): """Wrapper class of the `fastapi.security.OAuth2` class.""" -@use_cookies -class OAuth2PasswordBearer(FastAPIPasswordBearer): +class OAuth2PasswordBearer(FastAPIPasswordBearer, metaclass=OAuth2Cookie): """Wrapper class of the `fastapi.security.OAuth2PasswordBearer` class.""" -@use_cookies -class OAuth2AuthorizationCodeBearer(FastAPICodeBearer): +class OAuth2AuthorizationCodeBearer(FastAPICodeBearer, metaclass=OAuth2Cookie): """Wrapper class of the `fastapi.security.OAuth2AuthorizationCodeBearer` class.""" diff --git a/tests/conftest.py b/tests/conftest.py index aedb52a..b96e6c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,18 @@ import pytest import social_core.backends as backends +from fastapi import APIRouter +from fastapi import Depends +from fastapi import FastAPI +from fastapi import Request +from social_core.backends.github import GithubOAuth2 from social_core.backends.oauth import BaseOAuth2 +from starlette.responses import Response + +from fastapi_oauth2.client import OAuth2Client +from fastapi_oauth2.middleware import OAuth2Middleware +from fastapi_oauth2.router import router as oauth2_router +from fastapi_oauth2.security import OAuth2 package_path = backends.__path__[0] @@ -24,3 +35,52 @@ def backends(): except ImportError: continue return backend_instances + + +@pytest.fixture +def get_app(): + def fixture_wrapper(authentication: OAuth2 = None): + if not authentication: + authentication = OAuth2() + + oauth2 = authentication + application = FastAPI() + app_router = APIRouter() + + @app_router.get("/user") + def user(request: Request, _: str = Depends(oauth2)): + return request.user + + @app_router.get("/auth") + def auth(request: Request): + access_token = request.auth.jwt_create({ + "name": "test", + "sub": "test", + "id": "test", + }) + response = Response() + response.set_cookie( + "Authorization", + value=f"Bearer {access_token}", + max_age=request.auth.expires, + expires=request.auth.expires, + httponly=request.auth.http, + ) + return response + + application.include_router(app_router) + application.include_router(oauth2_router) + application.add_middleware(OAuth2Middleware, config={ + "allow_http": True, + "clients": [ + OAuth2Client( + backend=GithubOAuth2, + client_id="test_id", + client_secret="test_secret", + ), + ], + }) + + return application + + return fixture_wrapper diff --git a/tests/test_backends.py b/tests/test_backends.py new file mode 100644 index 0000000..47a91d6 --- /dev/null +++ b/tests/test_backends.py @@ -0,0 +1,17 @@ +import pytest + +from fastapi_oauth2.client import OAuth2Client +from fastapi_oauth2.core import OAuth2Core + + +@pytest.mark.anyio +async def test_core_init_with_all_backends(backends): + for backend in backends: + try: + OAuth2Core(OAuth2Client( + backend=backend, + client_id="test_client_id", + client_secret="test_client_secret", + )) + except (NotImplementedError, Exception): + assert False diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 0000000..e33c6b7 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,28 @@ +import pytest +from httpx import AsyncClient + + +@pytest.mark.anyio +async def test_middleware_on_authentication(get_app): + async with AsyncClient(app=get_app(), base_url="http://test") as client: + response = await client.get("/user") + assert response.status_code == 403 # Forbidden + + await client.get("/auth") # Simulate login + + response = await client.get("/user") + assert response.status_code == 200 # OK + + +@pytest.mark.anyio +async def test_middleware_on_logout(get_app): + async with AsyncClient(app=get_app(), base_url="http://test") as client: + await client.get("/auth") # Simulate login + + response = await client.get("/user") + assert response.status_code == 200 # OK + + await client.get("/oauth2/logout") # Perform logout + + response = await client.get("/user") + assert response.status_code == 403 # Forbidden diff --git a/tests/test_oauth2_middleware.py b/tests/test_oauth2_middleware.py deleted file mode 100644 index 27d0752..0000000 --- a/tests/test_oauth2_middleware.py +++ /dev/null @@ -1,87 +0,0 @@ -import pytest -from fastapi import APIRouter -from fastapi import Depends -from fastapi import FastAPI -from fastapi import Request -from httpx import AsyncClient -from social_core.backends.github import GithubOAuth2 -from starlette.responses import Response - -from fastapi_oauth2.client import OAuth2Client -from fastapi_oauth2.core import OAuth2Core -from fastapi_oauth2.middleware import OAuth2Middleware -from fastapi_oauth2.router import router as oauth2_router -from fastapi_oauth2.security import OAuth2 - -app = FastAPI() -oauth2 = OAuth2() -app_router = APIRouter() - - -@app_router.get("/user") -def user(request: Request, _: str = Depends(oauth2)): - return request.user - - -@app_router.get("/auth") -def auth(request: Request): - access_token = request.auth.jwt_create({ - "name": "test", - "sub": "test", - "id": "test", - }) - response = Response() - response.set_cookie( - "Authorization", - value=f"Bearer {access_token}", - max_age=request.auth.expires, - expires=request.auth.expires, - httponly=request.auth.http, - ) - return response - - -app.include_router(app_router) -app.include_router(oauth2_router) -app.add_middleware(OAuth2Middleware, config={ - "allow_http": True, - "clients": [ - OAuth2Client( - backend=GithubOAuth2, - client_id="test_id", - client_secret="test_secret", - ), - ], -}) - - -@pytest.mark.anyio -async def test_auth_redirect(): - async with AsyncClient(app=app, base_url="http://test") as client: - response = await client.get("/oauth2/github/auth") - assert response.status_code == 303 # Redirect - - -@pytest.mark.anyio -async def test_authenticated_request(): - async with AsyncClient(app=app, base_url="http://test") as client: - response = await client.get("/user") - assert response.status_code == 403 # Forbidden - - await client.get("/auth") # Simulate login - - response = await client.get("/user") - assert response.status_code == 200 # OK - - -@pytest.mark.anyio -async def test_core_init(backends): - for backend in backends: - try: - OAuth2Core(OAuth2Client( - backend=backend, - client_id="test_client_id", - client_secret="test_client_secret", - )) - except (NotImplementedError, Exception): - assert False diff --git a/tests/test_router.py b/tests/test_router.py new file mode 100644 index 0000000..084f459 --- /dev/null +++ b/tests/test_router.py @@ -0,0 +1,26 @@ +import pytest +from httpx import AsyncClient + + +@pytest.mark.anyio +async def test_auth_redirect(get_app): + async with AsyncClient(app=get_app(), base_url="http://test") as client: + response = await client.get("/oauth2/github/auth") + assert response.status_code == 303 # Redirect + + +@pytest.mark.anyio +async def test_token_redirect(get_app): + async with AsyncClient(app=get_app(), base_url="http://test") as client: + response = await client.get("/oauth2/github/token") + assert response.status_code == 400 # Bad Request + + response = await client.get("/oauth2/github/token?state=test&code=test") + assert response.status_code == 400 # Bad Request + + +@pytest.mark.anyio +async def test_logout_redirect(get_app): + async with AsyncClient(app=get_app(), base_url="http://test") as client: + response = await client.get("/oauth2/logout") + assert response.status_code == 307 # Redirect diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..9c8fa1f --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,29 @@ +import pytest + +from fastapi_oauth2.security import OAuth2 +from fastapi_oauth2.security import OAuth2AuthorizationCodeBearer +from fastapi_oauth2.security import OAuth2PasswordBearer + + +@pytest.mark.anyio +async def test_security_oauth2(get_app): + try: + get_app(OAuth2()) + except (TypeError, Exception): + assert False + + +@pytest.mark.anyio +async def test_security_oauth2_password_bearer(get_app): + try: + get_app(OAuth2PasswordBearer(tokenUrl="/test")) + except (TypeError, Exception): + assert False + + +@pytest.mark.anyio +async def test_security_oauth2_authentication_code_bearer(get_app): + try: + get_app(OAuth2AuthorizationCodeBearer(authorizationUrl="/test", tokenUrl="/test")) + except (TypeError, Exception): + assert False