Made db session persist through the entire web request

This commit is contained in:
Yarne Coppens 2024-08-15 11:07:11 +02:00
parent 1bc8733cf9
commit 92849b898b
3 changed files with 139 additions and 133 deletions

View file

@ -3,6 +3,7 @@ from datetime import date, timedelta, datetime
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import Session from sqlmodel import Session
from fastapi import FastAPI, Depends from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -11,14 +12,24 @@ from src.classes import boardgame_classes, play_classes, statistic_classes
from src.modules import data_connection from src.modules import data_connection
from src.filters import boardgame_filters, play_filters from src.filters import boardgame_filters, play_filters
def get_session():
with Session(data_connection.get_db_engine()) as session:
yield session
def refresh_data():
data_connection.delete_database()
with Session(data_connection.get_db_engine()) as session:
data_connection.get_user_collection(session)
data_connection.get_user_owned_collection(session)
data_connection.get_user_wishlist_collection(session)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Startup # Startup
data_connection.delete_database() #data_connection.delete_database()
data_connection.create_db_and_tables() data_connection.create_db_and_tables()
# data_connection.get_user_collection()
# data_connection.get_user_owned_collection() #refresh_data()
data_connection.get_user_wishlist_collection()
yield yield
# Shutdown # Shutdown
@ -56,13 +67,13 @@ def read_root():
return {"Hello": "World"} return {"Hello": "World"}
@app.get("/boardgame", response_model=Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]) @app.get("/boardgame", response_model=Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion])
def get_boardgame_by_id(id: int): def get_boardgame_by_id(id: int, session: Session = Depends(get_session)):
requested_boardgame: Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion] = data_connection.get_boardgame(id) requested_boardgame: Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion] = data_connection.get_boardgame(session, id)
return requested_boardgame return requested_boardgame
@app.get("/owned", response_model=list[Union[boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion]]) @app.get("/owned", response_model=list[Union[boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion]])
def get_owned_collection(query: ExpansionFilteringParams = Depends()): def get_owned_collection(query: ExpansionFilteringParams = Depends(), session: Session = Depends(get_session)):
to_return_boardgames = data_connection.get_user_owned_collection() to_return_boardgames = data_connection.get_user_owned_collection(session)
if query.filter_expansions_out: if query.filter_expansions_out:
to_return_boardgames = boardgame_filters.filter_expansions_out(to_return_boardgames) to_return_boardgames = boardgame_filters.filter_expansions_out(to_return_boardgames)
@ -74,9 +85,9 @@ def get_owned_collection(query: ExpansionFilteringParams = Depends()):
@app.get("/wishlist", response_model=list[Union[boardgame_classes.WishlistBoardGame, boardgame_classes.WishlistBoardGameExpansion]]) @app.get("/wishlist", response_model=list[Union[boardgame_classes.WishlistBoardGame, boardgame_classes.WishlistBoardGameExpansion]])
def get_wishlist_collection(priority: int = 0, query: ExpansionFilteringParams = Depends()): def get_wishlist_collection(priority: int = 0, query: ExpansionFilteringParams = Depends(), session: Session = Depends(get_session)):
to_return_boardgames = data_connection.get_user_wishlist_collection(priority) to_return_boardgames = data_connection.get_user_wishlist_collection(session, priority)
if query.filter_expansions_out: if query.filter_expansions_out:
to_return_boardgames = boardgame_filters.filter_expansions_out(to_return_boardgames) to_return_boardgames = boardgame_filters.filter_expansions_out(to_return_boardgames)
@ -88,9 +99,9 @@ def get_wishlist_collection(priority: int = 0, query: ExpansionFilteringParams =
@app.get("/plays", response_model=list[play_classes.PlayPublicWithPlayers]) @app.get("/plays", response_model=list[play_classes.PlayPublicWithPlayers])
def get_plays(query: ExpansionFilteringParams = Depends(), boardgame_id: int = -1): def get_plays(query: ExpansionFilteringParams = Depends(), boardgame_id: int = -1, session: Session = Depends(get_session)):
requested_plays = data_connection.get_plays() requested_plays = data_connection.get_plays(session)
if query.filter_expansions_out: if query.filter_expansions_out:
requested_plays = play_filters.filter_expansions_out(requested_plays) requested_plays = play_filters.filter_expansions_out(requested_plays)
@ -105,15 +116,15 @@ def get_plays(query: ExpansionFilteringParams = Depends(), boardgame_id: int = -
@app.get('/players', response_model=list[play_classes.PlayPlayerPublic]) @app.get('/players', response_model=list[play_classes.PlayPlayerPublic])
def get_players_from_play(play_id: int): def get_players_from_play(play_id: int, session: Session = Depends(get_session)):
requested_players = data_connection.get_players_from_play(play_id) requested_players = data_connection.get_players_from_play(session, play_id)
return requested_players return requested_players
@app.get('/statistics/amount_of_games', response_model=statistic_classes.NumberStatistic) @app.get('/statistics/amount_of_games', response_model=statistic_classes.NumberStatistic)
def get_amount_of_games(query: ExpansionFilteringParams = Depends()): def get_amount_of_games(query: ExpansionFilteringParams = Depends(), session: Session = Depends(get_session)):
owned_collection = data_connection.get_user_owned_collection() owned_collection = data_connection.get_user_owned_collection(session)
if query.filter_expansions_out: if query.filter_expansions_out:
owned_collection = boardgame_filters.filter_expansions_out(owned_collection) owned_collection = boardgame_filters.filter_expansions_out(owned_collection)
@ -130,14 +141,14 @@ def get_amount_of_games(query: ExpansionFilteringParams = Depends()):
return statistic_to_return return statistic_to_return
@app.get('/statistics/amount_of_games_over_time', response_model=statistic_classes.TimeLineStatistic) @app.get('/statistics/amount_of_games_over_time', response_model=statistic_classes.TimeLineStatistic)
def get_amount_of_games_over_time(day_step: int = 1, filter_expansions_out: bool = False, only_expansions: bool = False): def get_amount_of_games_over_time(day_step: int = 1, filter_expansions_out: bool = False, only_expansions: bool = False, session: Session = Depends(get_session)):
def daterange(start_date: date, end_date: date, day_step): def daterange(start_date: date, end_date: date, day_step):
days = int((end_date - start_date).days) days = int((end_date - start_date).days)
for n in range(0, days, day_step): for n in range(0, days, day_step):
yield start_date + timedelta(n) yield start_date + timedelta(n)
games_in_owned_collection = data_connection.get_user_owned_collection() games_in_owned_collection = data_connection.get_user_owned_collection(session)
games_in_owned_collection.sort(key=lambda x: x.acquisition_date) games_in_owned_collection.sort(key=lambda x: x.acquisition_date)
start_date = games_in_owned_collection[0].acquisition_date start_date = games_in_owned_collection[0].acquisition_date
@ -164,8 +175,8 @@ def get_amount_of_games_over_time(day_step: int = 1, filter_expansions_out: bool
return statistic_to_return return statistic_to_return
@app.get('/statistics/games_played_per_year', response_model=statistic_classes.TimeLineStatistic) @app.get('/statistics/games_played_per_year', response_model=statistic_classes.TimeLineStatistic)
def get_amount_of_games_played_per_year(query: ExpansionFilteringParams = Depends()): def get_amount_of_games_played_per_year(query: ExpansionFilteringParams = Depends(), session: Session = Depends(get_session)):
all_plays = data_connection.get_plays() all_plays = data_connection.get_plays(session)
all_plays.sort(key= lambda x: x.play_date) all_plays.sort(key= lambda x: x.play_date)
@ -204,9 +215,9 @@ def get_amount_of_games_played_per_year(query: ExpansionFilteringParams = Depend
@app.get('/statistics/most_expensive_games', response_model=statistic_classes.GamesStatistic) @app.get('/statistics/most_expensive_games', response_model=statistic_classes.GamesStatistic)
def get_most_expensive_game(query: ExpansionFilteringParams = Depends(), top_amount: int = 10): def get_most_expensive_game(query: ExpansionFilteringParams = Depends(), top_amount: int = 10, session: Session = Depends(get_session)):
most_expensive_games: list[Union[boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion]] = data_connection.get_user_owned_collection() most_expensive_games: list[Union[boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion]] = data_connection.get_user_owned_collection(session)
if query.filter_expansions_out: if query.filter_expansions_out:
most_expensive_games = boardgame_filters.filter_expansions_out(most_expensive_games) most_expensive_games = boardgame_filters.filter_expansions_out(most_expensive_games)
@ -228,5 +239,5 @@ def get_most_expensive_game(query: ExpansionFilteringParams = Depends(), top_amo
return statistic_to_return return statistic_to_return
@app.get('/statistics/shelf_of_shame', response_model=statistic_classes.GamesStatistic) @app.get('/statistics/shelf_of_shame', response_model=statistic_classes.GamesStatistic)
def get_shelf_of_shame(query: ExpansionFilteringParams = Depends()): def get_shelf_of_shame(query: ExpansionFilteringParams = Depends(), session: Session = Depends(get_session)):
pass pass

View file

@ -1,4 +1,5 @@
from typing import Union from typing import Union
from sqlmodel import Session
from threading import Lock from threading import Lock
critical_function_lock = Lock() critical_function_lock = Lock()
@ -6,10 +7,13 @@ critical_function_lock = Lock()
from src.modules import bgg_connection, db_connection from src.modules import bgg_connection, db_connection
from src.classes import boardgame_classes, play_classes from src.classes import boardgame_classes, play_classes
def get_boardgame(boardgame_id: int) -> Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]: def get_db_engine():
return db_connection.get_engine()
def get_boardgame(session: Session, boardgame_id: int) -> Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]:
#Will check if it already exists in db, then it will get it from there #Will check if it already exists in db, then it will get it from there
boardgame_in_db = db_connection.get_boardgame(boardgame_id=boardgame_id) boardgame_in_db = db_connection.get_boardgame(session, boardgame_id=boardgame_id)
to_return_boardgame = None to_return_boardgame = None
@ -17,64 +21,64 @@ def get_boardgame(boardgame_id: int) -> Union[boardgame_classes.BoardGame, board
to_return_boardgame = boardgame_in_db to_return_boardgame = boardgame_in_db
else: else:
to_return_boardgame = bgg_connection.get_boardgame(boardgame_id) to_return_boardgame = bgg_connection.get_boardgame(boardgame_id)
db_connection.add_boardgame(to_return_boardgame) db_connection.add_boardgame(session, to_return_boardgame)
to_return_boardgame = db_connection.get_boardgame(boardgame_id) to_return_boardgame = db_connection.get_boardgame(session, boardgame_id)
return to_return_boardgame return to_return_boardgame
def get_multiple_boardgames(boardgame_ids: list[int]) -> list[Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]]: def get_multiple_boardgames(session: Session, boardgame_ids: list[int]) -> list[Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]]:
boardgames_in_db, boardgame_ids_missing = db_connection.get_multiple_boardgames(boardgame_ids=boardgame_ids) boardgames_in_db, boardgame_ids_missing = db_connection.get_multiple_boardgames(session, boardgame_ids=boardgame_ids)
if len(boardgame_ids_missing) != 0: if len(boardgame_ids_missing) != 0:
missing_boardgames = bgg_connection.get_multiple_boardgames(boardgame_ids_missing) missing_boardgames = bgg_connection.get_multiple_boardgames(boardgame_ids_missing)
db_connection.add_multiple_boardgames(missing_boardgames) db_connection.add_multiple_boardgames(session, missing_boardgames)
boardgames_in_db, boardgame_ids_missing = db_connection.get_multiple_boardgames(boardgame_ids=boardgame_ids) boardgames_in_db, boardgame_ids_missing = db_connection.get_multiple_boardgames(session, boardgame_ids=boardgame_ids)
return boardgames_in_db return boardgames_in_db
def get_user_collection() -> list[Union[boardgame_classes.BoardGame, boardgame_classes.OwnedBoardGame]]: def get_user_collection(session: Session, ) -> list[Union[boardgame_classes.BoardGame, boardgame_classes.OwnedBoardGame]]:
boardgames_from_db: list[boardgame_classes.BoardGame] = db_connection.get_all_boardgames(boardgame_classes.BoardGame) boardgames_from_db: list[boardgame_classes.BoardGame] = db_connection.get_all_boardgames(session, boardgame_classes.BoardGame)
boardgame_expansions_from_db: list[boardgame_classes.BoardGameExpansion] = db_connection.get_all_boardgames(boardgame_classes.BoardGameExpansion) boardgame_expansions_from_db: list[boardgame_classes.BoardGameExpansion] = db_connection.get_all_boardgames(session, boardgame_classes.BoardGameExpansion)
if len(boardgames_from_db) == 0 and len(boardgame_expansions_from_db) == 0: if len(boardgames_from_db) == 0 and len(boardgame_expansions_from_db) == 0:
boardgames = bgg_connection.get_user_collection() boardgames = bgg_connection.get_user_collection()
db_connection.add_multiple_boardgames(boardgames) db_connection.add_multiple_boardgames(session, boardgames)
boardgames_from_db: list[boardgame_classes.BoardGame] = db_connection.get_all_boardgames(boardgame_classes.BoardGame) boardgames_from_db: list[boardgame_classes.BoardGame] = db_connection.get_all_boardgames(session, boardgame_classes.BoardGame)
boardgame_expansions_from_db: list[boardgame_classes.BoardGameExpansion] = db_connection.get_all_boardgames(boardgame_classes.BoardGameExpansion) boardgame_expansions_from_db: list[boardgame_classes.BoardGameExpansion] = db_connection.get_all_boardgames(session, boardgame_classes.BoardGameExpansion)
return boardgames_from_db + boardgame_expansions_from_db return boardgames_from_db + boardgame_expansions_from_db
def get_user_owned_collection() -> list[Union[boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion]]: def get_user_owned_collection(session: Session, ) -> list[Union[boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion]]:
owned_boardgames_from_db: list[boardgame_classes.OwnedBoardGame] = db_connection.get_all_boardgames(boardgame_classes.OwnedBoardGame) owned_boardgames_from_db: list[boardgame_classes.OwnedBoardGame] = db_connection.get_all_boardgames(session, boardgame_classes.OwnedBoardGame)
owned_boardgame_expanions_from_db: list[boardgame_classes.OwnedBoardGameExpansion] = db_connection.get_all_boardgames(boardgame_classes.OwnedBoardGameExpansion) owned_boardgame_expanions_from_db: list[boardgame_classes.OwnedBoardGameExpansion] = db_connection.get_all_boardgames(session, boardgame_classes.OwnedBoardGameExpansion)
if len(owned_boardgames_from_db) == 0 and len(owned_boardgame_expanions_from_db) == 0: if len(owned_boardgames_from_db) == 0 and len(owned_boardgame_expanions_from_db) == 0:
owned_boardgames = bgg_connection.get_user_owned_collection() owned_boardgames = bgg_connection.get_user_owned_collection()
db_connection.add_multiple_boardgames(owned_boardgames) db_connection.add_multiple_boardgames(session, owned_boardgames)
owned_boardgames_from_db: list[boardgame_classes.OwnedBoardGame] = db_connection.get_all_boardgames(boardgame_classes.OwnedBoardGame) owned_boardgames_from_db: list[boardgame_classes.OwnedBoardGame] = db_connection.get_all_boardgames(session, boardgame_classes.OwnedBoardGame)
owned_boardgame_expanions_from_db: list[boardgame_classes.OwnedBoardGameExpansion] = db_connection.get_all_boardgames(boardgame_classes.OwnedBoardGameExpansion) owned_boardgame_expanions_from_db: list[boardgame_classes.OwnedBoardGameExpansion] = db_connection.get_all_boardgames(session, boardgame_classes.OwnedBoardGameExpansion)
return owned_boardgames_from_db + owned_boardgame_expanions_from_db return owned_boardgames_from_db + owned_boardgame_expanions_from_db
def get_user_wishlist_collection(wishlist_priority: int = 0) -> Union[list[boardgame_classes.WishlistBoardGame], list[boardgame_classes.WishlistBoardGameExpansion]]: def get_user_wishlist_collection(session: Session, wishlist_priority: int = 0) -> Union[list[boardgame_classes.WishlistBoardGame], list[boardgame_classes.WishlistBoardGameExpansion]]:
wishlisted_boardgames_from_db = db_connection.get_all_boardgames(boardgame_classes.WishlistBoardGame) wishlisted_boardgames_from_db = db_connection.get_all_boardgames(session, boardgame_classes.WishlistBoardGame)
wishlisted_boardgame_expansions_from_db = db_connection.get_all_boardgames(boardgame_classes.WishlistBoardGameExpansion) wishlisted_boardgame_expansions_from_db = db_connection.get_all_boardgames(session, boardgame_classes.WishlistBoardGameExpansion)
if len(wishlisted_boardgames_from_db) == 0 and len(wishlisted_boardgame_expansions_from_db) == 0: if len(wishlisted_boardgames_from_db) == 0 and len(wishlisted_boardgame_expansions_from_db) == 0:
wishlisted_boardgames = bgg_connection.get_user_wishlist_collection() wishlisted_boardgames = bgg_connection.get_user_wishlist_collection()
db_connection.add_multiple_boardgames(wishlisted_boardgames) db_connection.add_multiple_boardgames(session, wishlisted_boardgames)
wishlisted_boardgames_from_db = db_connection.get_all_boardgames(boardgame_classes.WishlistBoardGame) wishlisted_boardgames_from_db = db_connection.get_all_boardgames(session, boardgame_classes.WishlistBoardGame)
wishlisted_boardgame_expansions_from_db = db_connection.get_all_boardgames(boardgame_classes.WishlistBoardGameExpansion) wishlisted_boardgame_expansions_from_db = db_connection.get_all_boardgames(session, boardgame_classes.WishlistBoardGameExpansion)
to_return_boardgames = wishlisted_boardgames_from_db + wishlisted_boardgame_expansions_from_db to_return_boardgames = wishlisted_boardgames_from_db + wishlisted_boardgame_expansions_from_db
@ -84,16 +88,16 @@ def get_user_wishlist_collection(wishlist_priority: int = 0) -> Union[list[board
return to_return_boardgames return to_return_boardgames
def get_plays() -> list[play_classes.Play]: def get_plays(session: Session, ) -> list[play_classes.Play]:
plays_from_db = db_connection.get_plays() plays_from_db = db_connection.get_plays(session)
if len(plays_from_db) == 0: if len(plays_from_db) == 0:
all_plays = bgg_connection.get_plays() all_plays = bgg_connection.get_plays()
db_connection.add_multiple_plays(all_plays) db_connection.add_multiple_plays(session, all_plays)
plays_from_db = db_connection.get_plays() plays_from_db = db_connection.get_plays(session)
#Making sure all played board games are in table 'boardgames' #Making sure all played board games are in table 'boardgames'
#list + set to remove duplicates #list + set to remove duplicates
@ -106,19 +110,19 @@ def get_plays() -> list[play_classes.Play]:
assert len(list(filter(lambda x: x == None, played_boardgame_ids))) == 0, plays_from_db assert len(list(filter(lambda x: x == None, played_boardgame_ids))) == 0, plays_from_db
get_multiple_boardgames(played_boardgame_ids + played_expansion_ids) get_multiple_boardgames(session, played_boardgame_ids + played_expansion_ids)
return plays_from_db return plays_from_db
def get_players_from_play(play_id: int) -> list[play_classes.PlayPlayer]: def get_players_from_play(session: Session, play_id: int) -> list[play_classes.PlayPlayer]:
players_from_db = db_connection.get_players_from_play(play_id) players_from_db = db_connection.get_players_from_play(session, play_id)
if len(players_from_db) == 0: if len(players_from_db) == 0:
all_plays = bgg_connection.get_plays() all_plays = bgg_connection.get_plays()
db_connection.add_multiple_plays(all_plays) db_connection.add_multiple_plays(session, all_plays)
players_from_db = db_connection.get_players_from_play(play_id) players_from_db = db_connection.get_players_from_play(session, play_id)
return players_from_db return players_from_db

View file

@ -13,13 +13,15 @@ sqlite_url = definitions.SQLITE_URL
connect_args = {"check_same_thread": False} connect_args = {"check_same_thread": False}
engine = create_engine(sqlite_url, echo=True, connect_args=connect_args) engine = create_engine(sqlite_url, echo=True, connect_args=connect_args)
def add_boardgame(boardgame: Union[ def get_engine():
return engine
def add_boardgame(session: Session, boardgame: Union[
boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion, boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion,
boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion, boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion,
boardgame_classes.WishlistBoardGame, boardgame_classes.WishlistBoardGameExpansion]): boardgame_classes.WishlistBoardGame, boardgame_classes.WishlistBoardGameExpansion]):
with critical_function_lock: with critical_function_lock:
with Session(engine) as session:
is_boardgame_present = len(session.exec( is_boardgame_present = len(session.exec(
select(boardgame.__class__).where(boardgame.__class__.id == boardgame.id) select(boardgame.__class__).where(boardgame.__class__.id == boardgame.id)
).all()) != 0 ).all()) != 0
@ -29,13 +31,12 @@ def add_boardgame(boardgame: Union[
session.commit() session.commit()
def add_multiple_boardgames(boardgame_list: list[Union[ def add_multiple_boardgames(session: Session, boardgame_list: list[Union[
boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion, boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion,
boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion, boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion,
boardgame_classes.WishlistBoardGame, boardgame_classes.WishlistBoardGameExpansion]]): boardgame_classes.WishlistBoardGame, boardgame_classes.WishlistBoardGameExpansion]]):
with critical_function_lock: with critical_function_lock:
with Session(engine) as session:
for boardgame in boardgame_list: for boardgame in boardgame_list:
is_boardgame_present = len(session.exec( is_boardgame_present = len(session.exec(
select(boardgame.__class__).where(boardgame.__class__.id == boardgame.id) select(boardgame.__class__).where(boardgame.__class__.id == boardgame.id)
@ -46,10 +47,9 @@ def add_multiple_boardgames(boardgame_list: list[Union[
session.commit() session.commit()
def get_boardgame(boardgame_id: int) -> Union[ def get_boardgame(session: Session, boardgame_id: int) -> Union[
boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]: boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]:
with Session(engine) as session:
statement = select(boardgame_classes.BoardGame).where(boardgame_classes.BoardGame.id == boardgame_id) statement = select(boardgame_classes.BoardGame).where(boardgame_classes.BoardGame.id == boardgame_id)
base_boardgames = session.exec(statement).all() base_boardgames = session.exec(statement).all()
@ -67,10 +67,8 @@ def get_boardgame(boardgame_id: int) -> Union[
return boardgame return boardgame
def get_multiple_boardgames(boardgame_ids: list[int]) -> tuple[Union[ def get_multiple_boardgames(session: Session, boardgame_ids: list[int]) -> tuple[Union[
list[boardgame_classes.BoardGame], list[boardgame_classes.BoardGameExpansion]], list[int]]: list[boardgame_classes.BoardGame], list[boardgame_classes.BoardGameExpansion]], list[int]]:
with Session(engine) as session:
statement = select(boardgame_classes.BoardGame).where(boardgame_classes.BoardGame.id.in_(boardgame_ids)) statement = select(boardgame_classes.BoardGame).where(boardgame_classes.BoardGame.id.in_(boardgame_ids))
results = session.exec(statement) results = session.exec(statement)
@ -87,12 +85,11 @@ def get_multiple_boardgames(boardgame_ids: list[int]) -> tuple[Union[
return boardgames, missing_boardgame_ids return boardgames, missing_boardgame_ids
def get_all_boardgames(boardgame_type: SQLModel) -> Union[ def get_all_boardgames(session: Session, boardgame_type: SQLModel) -> Union[
list[boardgame_classes.BoardGame], list[boardgame_classes.BoardGameExpansion], list[boardgame_classes.BoardGame], list[boardgame_classes.BoardGameExpansion],
list[boardgame_classes.OwnedBoardGame], list[boardgame_classes.OwnedBoardGameExpansion], list[boardgame_classes.OwnedBoardGame], list[boardgame_classes.OwnedBoardGameExpansion],
list[boardgame_classes.WishlistBoardGame], list[boardgame_classes.WishlistBoardGameExpansion]]: list[boardgame_classes.WishlistBoardGame], list[boardgame_classes.WishlistBoardGameExpansion]]:
with Session(engine) as session:
statement = select(boardgame_type) statement = select(boardgame_type)
results = session.exec(statement) results = session.exec(statement)
@ -101,20 +98,16 @@ def get_all_boardgames(boardgame_type: SQLModel) -> Union[
return boardgame_list return boardgame_list
def add_play(play: play_classes.Play): def add_play(session: Session, play: play_classes.Play):
with critical_function_lock: with critical_function_lock:
with Session(engine) as session:
session.add(play) session.add(play)
session.commit() session.commit()
session.refresh(play)
def add_multiple_plays(play_list: list[play_classes.Play]): def add_multiple_plays(session: Session, play_list: list[play_classes.Play]):
with critical_function_lock: with critical_function_lock:
with Session(engine) as session:
for play in play_list: for play in play_list:
is_play_present = len(session.exec(select(play_classes.Play).where(play_classes.Play.id == play.id)).all()) != 0 is_play_present = len(session.exec(select(play_classes.Play).where(play_classes.Play.id == play.id)).all()) != 0
@ -123,8 +116,7 @@ def add_multiple_plays(play_list: list[play_classes.Play]):
session.commit() session.commit()
def get_plays() -> list[play_classes.Play]: def get_plays(session: Session, ) -> list[play_classes.Play]:
with Session(engine) as session:
statement = select(play_classes.Play) statement = select(play_classes.Play)
results = session.exec(statement) results = session.exec(statement)
@ -132,8 +124,7 @@ def get_plays() -> list[play_classes.Play]:
return play_list return play_list
def get_players_from_play(play_id: int) -> list[play_classes.PlayPlayer]: def get_players_from_play(session: Session, play_id: int) -> list[play_classes.PlayPlayer]:
with Session(engine) as session:
statement = select(play_classes.PlayPlayer).where(play_classes.PlayPlayer.play_id == play_id) statement = select(play_classes.PlayPlayer).where(play_classes.PlayPlayer.play_id == play_id)
results = session.exec(statement) results = session.exec(statement)