From 3fe844ec8f1dd955540081241690fe5374dee6df Mon Sep 17 00:00:00 2001 From: Yarne Coppens Date: Tue, 10 Sep 2024 20:22:24 +0200 Subject: [PATCH] Have working auth flow with test data --- src/data_connection.py | 13 +++++++++-- src/main.py | 43 ++++++++++++++++++++++++++++++++---- src/modules/auth_manager.py | 11 +++++++++ src/modules/db_connection.py | 8 ++++++- 4 files changed, 68 insertions(+), 7 deletions(-) create mode 100644 src/modules/auth_manager.py diff --git a/src/data_connection.py b/src/data_connection.py index 7b24a9c..cd7a0d5 100644 --- a/src/data_connection.py +++ b/src/data_connection.py @@ -1,5 +1,9 @@ -from src.classes import product_classes +from src.classes import product_classes, user_classes from src.modules import db_connection +from sqlmodel import Session +from typing import Annotated +from src.modules import auth_manager +from fastapi import Depends melon = product_classes.Product(name="Meloen", price=2.0, barcode=1000 ,image_filename="melon") @@ -26,7 +30,12 @@ def get_single_product(barcode: int) -> product_classes.Product: for product in product_list: if product.barcode == barcode: return product - + +def get_user_by_username(session: Session, username: str) -> user_classes.UserInDB: + return db_connection.get_user_by_username(session, username) + + + def create_db_and_tables() -> None: db_connection.create_db_and_tables() diff --git a/src/main.py b/src/main.py index 98bea47..645c48e 100644 --- a/src/main.py +++ b/src/main.py @@ -1,14 +1,18 @@ -from fastapi import FastAPI +from typing import Annotated + +from fastapi import FastAPI, Depends, HTTPException, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm import os from contextlib import asynccontextmanager from sqlmodel import Session + from src.classes import product_classes, cash_classes, user_classes from src.config import definitions -from src.modules import price_to_cash_calculator +from src.modules import price_to_cash_calculator, auth_manager from shop_validators import image_validator from src import data_connection @@ -18,11 +22,12 @@ async def get_session(): def create_test_data(): user_1 = user_classes.UserInDB(username="yarninator", email="yarn@inator.com",full_name="Yarn Inator", disabled=False, hashed_password="abcdefghijklmnop") + user_2 = user_classes.UserInDB(username="lorinator", email="lor@inator.com", full_name="Lor Inator", disabled=False, hashed_password="abcdefghijklmnop") with Session(data_connection.get_db_engine()) as session: session.add(user_1) + session.add(user_2) session.commit() - @asynccontextmanager async def lifespan(app: FastAPI): # Startup @@ -33,6 +38,7 @@ async def lifespan(app: FastAPI): # Shutdown app = FastAPI(lifespan=lifespan) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") origins = [ "*" #TODO change this @@ -70,4 +76,33 @@ def get_icon(icon_filename: str): def price_to_cash(price: int): cash_model = price_to_cash_calculator.price_to_cash_model(price) - return cash_model \ No newline at end of file + return cash_model + + +@app.post("/token") +def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], session: Session = Depends(get_session)): + user = data_connection.get_user_by_username(session, form_data.username) + if not user: + raise HTTPException(status_code=400, detail="Incorrect username or password") + hashed_password = auth_manager.hash_password(form_data.password) + if not hashed_password == user.hashed_password: + raise HTTPException(status_code=400, detail="Incorrect username or password") + + return {"access_token": user.username, "token_type": "bearer"} + + + +def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], session: Session = Depends(get_session)): + user = auth_manager.token_to_user(session, token) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + return user + +@app.get("/me", response_model=user_classes.UserPublic) +def get_me(current_user: Annotated[user_classes.UserInDB, Depends(get_current_user)]): + return current_user + diff --git a/src/modules/auth_manager.py b/src/modules/auth_manager.py new file mode 100644 index 0000000..429dd9f --- /dev/null +++ b/src/modules/auth_manager.py @@ -0,0 +1,11 @@ +from src.classes import user_classes +from src import data_connection + +#TODO change this +def hash_password(password: str): + return password + +#TODO change this +def token_to_user(session, token): + print(token) + return data_connection.get_user_by_username(session, token) \ No newline at end of file diff --git a/src/modules/db_connection.py b/src/modules/db_connection.py index 56ab8dc..06d3217 100644 --- a/src/modules/db_connection.py +++ b/src/modules/db_connection.py @@ -1,5 +1,6 @@ -from sqlmodel import create_engine, SQLModel +from sqlmodel import create_engine, SQLModel, Session, select from src.config import definitions +from src.classes import user_classes sqlite_url = definitions.SQLITE_URL @@ -7,6 +8,11 @@ sqlite_url = definitions.SQLITE_URL connect_args = {"check_same_thread": False} engine = create_engine(sqlite_url, echo=True, connect_args=connect_args) +def get_user_by_username(session: Session, username: str): + statement = select(user_classes.UserInDB).where(user_classes.UserInDB.username == username) + return session.exec(statement).first() + + def get_engine(): return engine