diff --git a/.gitignore b/.gitignore index 78102b5..935a175 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,4 @@ cython_debug/ #.idea/ .vscode/ +db/database.db diff --git a/src/classes/user_classes.py b/src/classes/user_classes.py index 619312b..70afec2 100644 --- a/src/classes/user_classes.py +++ b/src/classes/user_classes.py @@ -9,7 +9,8 @@ class UserBase(SQLModel): class UserPublic(UserBase): username: str email: str - + full_name: str + class UserInDB(UserBase, table=True): id: int | None = Field(default=None, primary_key=True) hashed_password: str \ No newline at end of file diff --git a/src/data_connection.py b/src/data_connection.py index f36c5ae..160ff3c 100644 --- a/src/data_connection.py +++ b/src/data_connection.py @@ -1,4 +1,5 @@ from src.classes import product_classes +from src.modules import db_connection melon = product_classes.Product(name="Meloen", price=2.0, barcode=1000 ,image_filename="melon") @@ -24,4 +25,7 @@ def get_all_products() -> list[product_classes.Product]: def get_single_product(barcode: int) -> product_classes.Product: for product in product_list: if product.barcode == barcode: - return product \ No newline at end of file + return product + +def create_db_and_tables() -> None: + db_connection.create_db_and_tables() \ No newline at end of file diff --git a/src/main.py b/src/main.py index acbd8e3..cb378b5 100644 --- a/src/main.py +++ b/src/main.py @@ -2,14 +2,23 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse import os +from contextlib import asynccontextmanager -from src.classes import product_classes, cash_classes + +from src.classes import product_classes, cash_classes, user_classes from src.config import definitions from src.modules import price_to_cash_calculator from shop_validators import image_validator from src import data_connection -app = FastAPI() +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup + data_connection.create_db_and_tables() + yield + # Shutdown + +app = FastAPI(lifespan=lifespan) origins = [ "*" #TODO change this @@ -23,7 +32,6 @@ app.add_middleware( allow_headers=["*"], ) - @app.get("/") def read_root(): return {"Hello": "World"} diff --git a/src/modules/db_connection.py b/src/modules/db_connection.py index a10657c..9ce784b 100644 --- a/src/modules/db_connection.py +++ b/src/modules/db_connection.py @@ -1,11 +1,14 @@ -from sqlmodel import create_engine +from sqlmodel import create_engine, SQLModel from src.config import definitions sqlite_url = definitions.SQLITE_URL connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, echo=False, connect_args=connect_args) +engine = create_engine(sqlite_url, echo=True, connect_args=connect_args) def get_engine(): - return engine \ No newline at end of file + return engine + +def create_db_and_tables() -> None: + SQLModel.metadata.create_all(engine) \ No newline at end of file