Have working auth flow with test data

This commit is contained in:
Yarne Coppens 2024-09-10 20:22:24 +02:00
parent f7cf0da546
commit 3fe844ec8f
4 changed files with 68 additions and 7 deletions

View file

@ -1,5 +1,9 @@
from src.classes import product_classes from src.classes import product_classes, user_classes
from src.modules import db_connection from src.modules import db_connection
from sqlmodel import Session
from typing import Annotated
from src.modules import auth_manager
from fastapi import Depends
melon = product_classes.Product(name="Meloen", price=2.0, barcode=1000 ,image_filename="melon") melon = product_classes.Product(name="Meloen", price=2.0, barcode=1000 ,image_filename="melon")
@ -27,6 +31,11 @@ def get_single_product(barcode: int) -> product_classes.Product:
if product.barcode == barcode: if product.barcode == barcode:
return product return product
def get_user_by_username(session: Session, username: str) -> user_classes.UserInDB:
return db_connection.get_user_by_username(session, username)
def create_db_and_tables() -> None: def create_db_and_tables() -> None:
db_connection.create_db_and_tables() db_connection.create_db_and_tables()

View file

@ -1,14 +1,18 @@
from fastapi import FastAPI from typing import Annotated
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from sqlmodel import Session from sqlmodel import Session
from src.classes import product_classes, cash_classes, user_classes from src.classes import product_classes, cash_classes, user_classes
from src.config import definitions from src.config import definitions
from src.modules import price_to_cash_calculator from src.modules import price_to_cash_calculator, auth_manager
from shop_validators import image_validator from shop_validators import image_validator
from src import data_connection from src import data_connection
@ -18,11 +22,12 @@ async def get_session():
def create_test_data(): def create_test_data():
user_1 = user_classes.UserInDB(username="yarninator", email="yarn@inator.com",full_name="Yarn Inator", disabled=False, hashed_password="abcdefghijklmnop") user_1 = user_classes.UserInDB(username="yarninator", email="yarn@inator.com",full_name="Yarn Inator", disabled=False, hashed_password="abcdefghijklmnop")
user_2 = user_classes.UserInDB(username="lorinator", email="lor@inator.com", full_name="Lor Inator", disabled=False, hashed_password="abcdefghijklmnop")
with Session(data_connection.get_db_engine()) as session: with Session(data_connection.get_db_engine()) as session:
session.add(user_1) session.add(user_1)
session.add(user_2)
session.commit() session.commit()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Startup # Startup
@ -33,6 +38,7 @@ async def lifespan(app: FastAPI):
# Shutdown # Shutdown
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
origins = [ origins = [
"*" #TODO change this "*" #TODO change this
@ -71,3 +77,32 @@ def price_to_cash(price: int):
cash_model = price_to_cash_calculator.price_to_cash_model(price) cash_model = price_to_cash_calculator.price_to_cash_model(price)
return cash_model return cash_model
@app.post("/token")
def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], session: Session = Depends(get_session)):
user = data_connection.get_user_by_username(session, form_data.username)
if not user:
raise HTTPException(status_code=400, detail="Incorrect username or password")
hashed_password = auth_manager.hash_password(form_data.password)
if not hashed_password == user.hashed_password:
raise HTTPException(status_code=400, detail="Incorrect username or password")
return {"access_token": user.username, "token_type": "bearer"}
def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], session: Session = Depends(get_session)):
user = auth_manager.token_to_user(session, token)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
return user
@app.get("/me", response_model=user_classes.UserPublic)
def get_me(current_user: Annotated[user_classes.UserInDB, Depends(get_current_user)]):
return current_user

View file

@ -0,0 +1,11 @@
from src.classes import user_classes
from src import data_connection
#TODO change this
def hash_password(password: str):
return password
#TODO change this
def token_to_user(session, token):
print(token)
return data_connection.get_user_by_username(session, token)

View file

@ -1,5 +1,6 @@
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine, SQLModel, Session, select
from src.config import definitions from src.config import definitions
from src.classes import user_classes
sqlite_url = definitions.SQLITE_URL sqlite_url = definitions.SQLITE_URL
@ -7,6 +8,11 @@ sqlite_url = definitions.SQLITE_URL
connect_args = {"check_same_thread": False} connect_args = {"check_same_thread": False}
engine = create_engine(sqlite_url, echo=True, connect_args=connect_args) engine = create_engine(sqlite_url, echo=True, connect_args=connect_args)
def get_user_by_username(session: Session, username: str):
statement = select(user_classes.UserInDB).where(user_classes.UserInDB.username == username)
return session.exec(statement).first()
def get_engine(): def get_engine():
return engine return engine