From 1bc8733cf91c38068b9d276e6f3863f7c7867abe Mon Sep 17 00:00:00 2001 From: Yarne Coppens Date: Thu, 15 Aug 2024 10:29:54 +0200 Subject: [PATCH] Made DB thread safe --- src/classes/statistic_classes.py | 2 +- src/main.py | 51 +++++----- src/modules/bgg_connection.py | 36 ++++++- src/modules/data_connection.py | 77 ++++++++------- src/modules/db_connection.py | 159 ++++++++++++++++--------------- tests/test_main.py | 2 +- 6 files changed, 186 insertions(+), 141 deletions(-) diff --git a/src/classes/statistic_classes.py b/src/classes/statistic_classes.py index a09a713..eb9229e 100644 --- a/src/classes/statistic_classes.py +++ b/src/classes/statistic_classes.py @@ -10,7 +10,7 @@ class StatisticBase(BaseModel): class NumberStatistic(StatisticBase): result: float -class GameOrderStatistic(StatisticBase): +class GamesStatistic(StatisticBase): result: list[Union[ boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion, boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion, diff --git a/src/main.py b/src/main.py index 38b150b..5569f64 100644 --- a/src/main.py +++ b/src/main.py @@ -11,14 +11,14 @@ from src.classes import boardgame_classes, play_classes, statistic_classes from src.modules import data_connection from src.filters import boardgame_filters, play_filters -def get_db_session(): - yield data_connection.get_db_session() - @asynccontextmanager async def lifespan(app: FastAPI): # Startup data_connection.delete_database() data_connection.create_db_and_tables() + # data_connection.get_user_collection() + # data_connection.get_user_owned_collection() + data_connection.get_user_wishlist_collection() yield # Shutdown @@ -56,13 +56,13 @@ def read_root(): return {"Hello": "World"} @app.get("/boardgame", response_model=Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]) -def get_boardgame_by_id(id: int, session: Session = Depends(get_db_session)): - requested_boardgame: Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion] = data_connection.get_boardgame(session, id) +def get_boardgame_by_id(id: int): + requested_boardgame: Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion] = data_connection.get_boardgame(id) return requested_boardgame @app.get("/owned", response_model=list[Union[boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion]]) -def get_owned_collection(query: ExpansionFilteringParams = Depends(), session: Session = Depends(get_db_session)): - to_return_boardgames = data_connection.get_user_owned_collection(session) +def get_owned_collection(query: ExpansionFilteringParams = Depends()): + to_return_boardgames = data_connection.get_user_owned_collection() if query.filter_expansions_out: to_return_boardgames = boardgame_filters.filter_expansions_out(to_return_boardgames) @@ -74,9 +74,9 @@ def get_owned_collection(query: ExpansionFilteringParams = Depends(), session: S @app.get("/wishlist", response_model=list[Union[boardgame_classes.WishlistBoardGame, boardgame_classes.WishlistBoardGameExpansion]]) -def get_wishlist_collection(priority: int = 0, query: ExpansionFilteringParams = Depends(), session: Session = Depends(get_db_session)): +def get_wishlist_collection(priority: int = 0, query: ExpansionFilteringParams = Depends()): - to_return_boardgames = data_connection.get_user_wishlist_collection(session, priority) + to_return_boardgames = data_connection.get_user_wishlist_collection(priority) if query.filter_expansions_out: to_return_boardgames = boardgame_filters.filter_expansions_out(to_return_boardgames) @@ -88,9 +88,9 @@ def get_wishlist_collection(priority: int = 0, query: ExpansionFilteringParams = @app.get("/plays", response_model=list[play_classes.PlayPublicWithPlayers]) -def get_plays(query: ExpansionFilteringParams = Depends(), boardgame_id: int = -1, session: Session = Depends(get_db_session)): +def get_plays(query: ExpansionFilteringParams = Depends(), boardgame_id: int = -1): - requested_plays = data_connection.get_plays(session) + requested_plays = data_connection.get_plays() if query.filter_expansions_out: requested_plays = play_filters.filter_expansions_out(requested_plays) @@ -105,16 +105,15 @@ def get_plays(query: ExpansionFilteringParams = Depends(), boardgame_id: int = - @app.get('/players', response_model=list[play_classes.PlayPlayerPublic]) -def get_players_from_play(play_id: int, session: Session = Depends(get_db_session)): - requested_players = data_connection.get_players_from_play(session, play_id) +def get_players_from_play(play_id: int): + requested_players = data_connection.get_players_from_play(play_id) return requested_players - @app.get('/statistics/amount_of_games', response_model=statistic_classes.NumberStatistic) -def get_amount_of_games(query: ExpansionFilteringParams = Depends(), session: Session = Depends(get_db_session)): +def get_amount_of_games(query: ExpansionFilteringParams = Depends()): - owned_collection = data_connection.get_user_owned_collection(session) + owned_collection = data_connection.get_user_owned_collection() if query.filter_expansions_out: owned_collection = boardgame_filters.filter_expansions_out(owned_collection) @@ -131,14 +130,14 @@ def get_amount_of_games(query: ExpansionFilteringParams = Depends(), session: Se return statistic_to_return @app.get('/statistics/amount_of_games_over_time', response_model=statistic_classes.TimeLineStatistic) -def get_amount_of_games_over_time(day_step: int = 1, filter_expansions_out: bool = False, only_expansions: bool = False, session: Session = Depends(get_db_session)): +def get_amount_of_games_over_time(day_step: int = 1, filter_expansions_out: bool = False, only_expansions: bool = False): def daterange(start_date: date, end_date: date, day_step): days = int((end_date - start_date).days) for n in range(0, days, day_step): yield start_date + timedelta(n) - games_in_owned_collection = data_connection.get_user_owned_collection(session) + games_in_owned_collection = data_connection.get_user_owned_collection() games_in_owned_collection.sort(key=lambda x: x.acquisition_date) start_date = games_in_owned_collection[0].acquisition_date @@ -165,8 +164,8 @@ def get_amount_of_games_over_time(day_step: int = 1, filter_expansions_out: bool return statistic_to_return @app.get('/statistics/games_played_per_year', response_model=statistic_classes.TimeLineStatistic) -def get_amount_of_games_played_per_year(query: ExpansionFilteringParams = Depends(), session: Session = Depends(get_db_session)): - all_plays = data_connection.get_plays(session) +def get_amount_of_games_played_per_year(query: ExpansionFilteringParams = Depends()): + all_plays = data_connection.get_plays() all_plays.sort(key= lambda x: x.play_date) @@ -204,10 +203,10 @@ def get_amount_of_games_played_per_year(query: ExpansionFilteringParams = Depend return statistic_to_return -@app.get('/statistics/most_expensive_games', response_model=statistic_classes.GameOrderStatistic) -def get_most_expensive_game(query: ExpansionFilteringParams = Depends(), top_amount: int = 10, session: Session = Depends(get_db_session)): +@app.get('/statistics/most_expensive_games', response_model=statistic_classes.GamesStatistic) +def get_most_expensive_game(query: ExpansionFilteringParams = Depends(), top_amount: int = 10): - most_expensive_games: list[Union[boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion]] = data_connection.get_user_owned_collection(session) + most_expensive_games: list[Union[boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion]] = data_connection.get_user_owned_collection() if query.filter_expansions_out: most_expensive_games = boardgame_filters.filter_expansions_out(most_expensive_games) @@ -224,6 +223,10 @@ def get_most_expensive_game(query: ExpansionFilteringParams = Depends(), top_amo "result":most_expensive_games } - statistic_to_return = statistic_classes.GameOrderStatistic(**statistic_dict) + statistic_to_return = statistic_classes.GamesStatistic(**statistic_dict) return statistic_to_return + +@app.get('/statistics/shelf_of_shame', response_model=statistic_classes.GamesStatistic) +def get_shelf_of_shame(query: ExpansionFilteringParams = Depends()): + pass \ No newline at end of file diff --git a/src/modules/bgg_connection.py b/src/modules/bgg_connection.py index df59f62..04b0568 100644 --- a/src/modules/bgg_connection.py +++ b/src/modules/bgg_connection.py @@ -7,6 +7,7 @@ import time import math from typing import Union import html +from tqdm import tqdm from src.classes import boardgame_classes, play_classes from src.modules import auth_manager @@ -16,7 +17,10 @@ authenticated_session: requests.Session = requests.Session() def url_to_xml_object(url: HttpUrl) -> ET.Element: - r = authenticated_session.get(url) + try: + r = authenticated_session.get(url) + except: + r = authenticated_session.get(url) while r.status_code == 202 or r.status_code == 429: if r.status_code == 202: @@ -50,7 +54,11 @@ def get_multiple_boardgames(boardgame_ids: list[int]) -> list[boardgame_classes. #Boardgamegeek only allows chunks of 20 boardgames at a time boardgame_ids_divided = list(divide_list_in_chunks(boardgame_ids)) - for boardgame_id_list_size_20 in boardgame_ids_divided: + for boardgame_id_list_size_20 in tqdm( + boardgame_ids_divided, + desc="Getting boardgames from BGG", + unit="requests"): + boardgame_id_list_commas: str = ','.join(map(str,boardgame_id_list_size_20)) url : str = "https://boardgamegeek.com/xmlapi2/thing?id={}&stats=true".format(boardgame_id_list_commas) boardgames_xml_object : ET.Element = url_to_xml_object(url) @@ -242,6 +250,10 @@ def get_boardgames_from_collection_url(collection_url: str, boardgame_type: boar for boardgame_item in collection_xml: boardgame_extra = boardgame_extras[current_index] match boardgame_type: + case boardgame_classes.BoardgameType.BOARDGAME: + boardgame = boardgame_extra + case boardgame_classes.BoardgameType.BOARDGAMEEXPANSION: + boardgame = boardgame_extra case boardgame_classes.BoardgameType.OWNEDBOARDGAME: boardgame = convert_collection_xml_to_owned_boardgame(boardgame_extra, boardgame_item) case boardgame_classes.BoardgameType.OWNEDBOARDGAMEEXPANSION: @@ -256,7 +268,18 @@ def get_boardgames_from_collection_url(collection_url: str, boardgame_type: boar return collection_list -def get_user_owned_collection() -> Union[list[boardgame_classes.OwnedBoardGame], list[boardgame_classes.OwnedBoardGameExpansion]]: +def get_user_collection() -> list[Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]]: + url_no_expansions = 'https://boardgamegeek.com/xmlapi2/collection?username={}&stats=1&excludesubtype=boardgameexpansion&showprivate=1&version=1'.format(auth_manager.username) + url_only_expansions = 'https://boardgamegeek.com/xmlapi2/collection?username={}&stats=1&subtype=boardgameexpansion&showprivate=1&version=1'.format(auth_manager.username) + + boardgames = get_boardgames_from_collection_url(url_no_expansions, boardgame_classes.BoardgameType.BOARDGAME) + boardgame_expansions = get_boardgames_from_collection_url(url_only_expansions, boardgame_classes.BoardgameType.BOARDGAMEEXPANSION) + + boardgames += boardgame_expansions + + return boardgames + +def get_user_owned_collection() -> list[Union[boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion]]: url_no_expansions = 'https://boardgamegeek.com/xmlapi2/collection?username={}&own=1&stats=1&excludesubtype=boardgameexpansion&showprivate=1&version=1'.format(auth_manager.username) url_only_expansions = 'https://boardgamegeek.com/xmlapi2/collection?username={}&own=1&stats=1&subtype=boardgameexpansion&showprivate=1&version=1'.format(auth_manager.username) @@ -290,8 +313,11 @@ def get_plays() -> list[play_classes.Play]: amount_of_pages_needed = math.ceil(amount_of_plays_total/float(definitions.BGG_PLAY_PAGE_SIZE)) all_plays : list[play_classes.Play] = [] - - for page in range(amount_of_pages_needed): + + for page in tqdm( + range(amount_of_pages_needed), + desc="Getting plays from BGG", + unit="requests"): url = 'https://boardgamegeek.com/xmlapi2/plays?username={}&page={}'.format(auth_manager.username, page + 1) plays_page_xml_object = url_to_xml_object(url) for play_xml in plays_page_xml_object: diff --git a/src/modules/data_connection.py b/src/modules/data_connection.py index 6ec28d8..9d145ed 100644 --- a/src/modules/data_connection.py +++ b/src/modules/data_connection.py @@ -1,16 +1,15 @@ from typing import Union -from sqlmodel import SQLModel, Session + +from threading import Lock +critical_function_lock = Lock() from src.modules import bgg_connection, db_connection from src.classes import boardgame_classes, play_classes -def get_db_session(): - return db_connection.get_session() - -def get_boardgame(session: Session, boardgame_id: int) -> Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]: +def get_boardgame(boardgame_id: int) -> Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]: #Will check if it already exists in db, then it will get it from there - boardgame_in_db = db_connection.get_boardgame(session, boardgame_id=boardgame_id) + boardgame_in_db = db_connection.get_boardgame(boardgame_id=boardgame_id) to_return_boardgame = None @@ -18,54 +17,64 @@ def get_boardgame(session: Session, boardgame_id: int) -> Union[boardgame_classe to_return_boardgame = boardgame_in_db else: to_return_boardgame = bgg_connection.get_boardgame(boardgame_id) - db_connection.add_boardgame(session, to_return_boardgame) - to_return_boardgame = db_connection.get_boardgame(session, boardgame_id) + db_connection.add_boardgame(to_return_boardgame) + to_return_boardgame = db_connection.get_boardgame(boardgame_id) return to_return_boardgame -def get_multiple_boardgames(session: Session, boardgame_ids: list[int]) -> list[Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]]: - - boardgames_in_db, boardgame_ids_missing = db_connection.get_multiple_boardgames(session, boardgame_ids=boardgame_ids) +def get_multiple_boardgames(boardgame_ids: list[int]) -> list[Union[boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]]: + boardgames_in_db, boardgame_ids_missing = db_connection.get_multiple_boardgames(boardgame_ids=boardgame_ids) if len(boardgame_ids_missing) != 0: missing_boardgames = bgg_connection.get_multiple_boardgames(boardgame_ids_missing) - db_connection.add_multiple_boardgames(session, missing_boardgames) + db_connection.add_multiple_boardgames(missing_boardgames) - boardgames_in_db, boardgame_ids_missing = db_connection.get_multiple_boardgames(session, boardgame_ids=boardgame_ids) + boardgames_in_db, boardgame_ids_missing = db_connection.get_multiple_boardgames(boardgame_ids=boardgame_ids) return boardgames_in_db +def get_user_collection() -> list[Union[boardgame_classes.BoardGame, boardgame_classes.OwnedBoardGame]]: + boardgames_from_db: list[boardgame_classes.BoardGame] = db_connection.get_all_boardgames(boardgame_classes.BoardGame) + boardgame_expansions_from_db: list[boardgame_classes.BoardGameExpansion] = db_connection.get_all_boardgames(boardgame_classes.BoardGameExpansion) + if len(boardgames_from_db) == 0 and len(boardgame_expansions_from_db) == 0: + boardgames = bgg_connection.get_user_collection() + db_connection.add_multiple_boardgames(boardgames) -def get_user_owned_collection(session: Session) -> list[Union[boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion]]: + boardgames_from_db: list[boardgame_classes.BoardGame] = db_connection.get_all_boardgames(boardgame_classes.BoardGame) + boardgame_expansions_from_db: list[boardgame_classes.BoardGameExpansion] = db_connection.get_all_boardgames(boardgame_classes.BoardGameExpansion) + + return boardgames_from_db + boardgame_expansions_from_db + +def get_user_owned_collection() -> list[Union[boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion]]: - owned_boardgames_from_db: list[boardgame_classes.OwnedBoardGame] = db_connection.get_all_boardgames(session, boardgame_classes.OwnedBoardGame) - owned_boardgame_expanions_from_db: list[boardgame_classes.OwnedBoardGameExpansion] = db_connection.get_all_boardgames(session, boardgame_classes.OwnedBoardGameExpansion) + owned_boardgames_from_db: list[boardgame_classes.OwnedBoardGame] = db_connection.get_all_boardgames(boardgame_classes.OwnedBoardGame) + owned_boardgame_expanions_from_db: list[boardgame_classes.OwnedBoardGameExpansion] = db_connection.get_all_boardgames(boardgame_classes.OwnedBoardGameExpansion) if len(owned_boardgames_from_db) == 0 and len(owned_boardgame_expanions_from_db) == 0: owned_boardgames = bgg_connection.get_user_owned_collection() - db_connection.add_multiple_boardgames(session, owned_boardgames) + db_connection.add_multiple_boardgames(owned_boardgames) - owned_boardgames_from_db: list[boardgame_classes.OwnedBoardGame] = db_connection.get_all_boardgames(session, boardgame_classes.OwnedBoardGame) - owned_boardgame_expanions_from_db: list[boardgame_classes.OwnedBoardGameExpansion] = db_connection.get_all_boardgames(session, boardgame_classes.OwnedBoardGameExpansion) + owned_boardgames_from_db: list[boardgame_classes.OwnedBoardGame] = db_connection.get_all_boardgames(boardgame_classes.OwnedBoardGame) + owned_boardgame_expanions_from_db: list[boardgame_classes.OwnedBoardGameExpansion] = db_connection.get_all_boardgames(boardgame_classes.OwnedBoardGameExpansion) return owned_boardgames_from_db + owned_boardgame_expanions_from_db -def get_user_wishlist_collection(session: Session, wishlist_priority: int = 0) -> Union[list[boardgame_classes.WishlistBoardGame], list[boardgame_classes.WishlistBoardGameExpansion]]: +def get_user_wishlist_collection(wishlist_priority: int = 0) -> Union[list[boardgame_classes.WishlistBoardGame], list[boardgame_classes.WishlistBoardGameExpansion]]: - wishlisted_boardgames_from_db = db_connection.get_all_boardgames(session, boardgame_classes.WishlistBoardGame) - wishlisted_boardgame_expansions_from_db = db_connection.get_all_boardgames(session, boardgame_classes.WishlistBoardGameExpansion) + wishlisted_boardgames_from_db = db_connection.get_all_boardgames(boardgame_classes.WishlistBoardGame) + wishlisted_boardgame_expansions_from_db = db_connection.get_all_boardgames(boardgame_classes.WishlistBoardGameExpansion) if len(wishlisted_boardgames_from_db) == 0 and len(wishlisted_boardgame_expansions_from_db) == 0: wishlisted_boardgames = bgg_connection.get_user_wishlist_collection() - db_connection.add_multiple_boardgames(session, wishlisted_boardgames) + db_connection.add_multiple_boardgames(wishlisted_boardgames) - wishlisted_boardgames_from_db = db_connection.get_all_boardgames(session, boardgame_classes.WishlistBoardGame) - wishlisted_boardgame_expansions_from_db = db_connection.get_all_boardgames(session, boardgame_classes.WishlistBoardGameExpansion) + wishlisted_boardgames_from_db = db_connection.get_all_boardgames(boardgame_classes.WishlistBoardGame) + wishlisted_boardgame_expansions_from_db = db_connection.get_all_boardgames(boardgame_classes.WishlistBoardGameExpansion) to_return_boardgames = wishlisted_boardgames_from_db + wishlisted_boardgame_expansions_from_db @@ -75,16 +84,16 @@ def get_user_wishlist_collection(session: Session, wishlist_priority: int = 0) - return to_return_boardgames -def get_plays(session: Session) -> list[play_classes.Play]: +def get_plays() -> list[play_classes.Play]: - plays_from_db = db_connection.get_plays(session) + plays_from_db = db_connection.get_plays() if len(plays_from_db) == 0: all_plays = bgg_connection.get_plays() - db_connection.add_multiple_plays(session, all_plays) + db_connection.add_multiple_plays(all_plays) - plays_from_db = db_connection.get_plays(session) + plays_from_db = db_connection.get_plays() #Making sure all played board games are in table 'boardgames' #list + set to remove duplicates @@ -97,19 +106,19 @@ def get_plays(session: Session) -> list[play_classes.Play]: assert len(list(filter(lambda x: x == None, played_boardgame_ids))) == 0, plays_from_db - get_multiple_boardgames(session, played_boardgame_ids + played_expansion_ids) + get_multiple_boardgames(played_boardgame_ids + played_expansion_ids) return plays_from_db -def get_players_from_play(session: Session, play_id: int) -> list[play_classes.PlayPlayer]: - players_from_db = db_connection.get_players_from_play(session, play_id) +def get_players_from_play(play_id: int) -> list[play_classes.PlayPlayer]: + players_from_db = db_connection.get_players_from_play(play_id) if len(players_from_db) == 0: all_plays = bgg_connection.get_plays() - db_connection.add_multiple_plays(session, all_plays) + db_connection.add_multiple_plays(all_plays) - players_from_db = db_connection.get_players_from_play(session, play_id) + players_from_db = db_connection.get_players_from_play(play_id) return players_from_db diff --git a/src/modules/db_connection.py b/src/modules/db_connection.py index dbb11f5..b7b7c0f 100644 --- a/src/modules/db_connection.py +++ b/src/modules/db_connection.py @@ -1,7 +1,9 @@ from sqlmodel import create_engine, SQLModel, Session, select from src.config import definitions from typing import Union -from sqlalchemy.orm import sessionmaker, scoped_session +from threading import Lock + +critical_function_lock = Lock() from src.classes import boardgame_classes, play_classes @@ -9,130 +11,135 @@ sqlite_url = definitions.SQLITE_URL connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, echo=False, connect_args=connect_args) +engine = create_engine(sqlite_url, echo=True, connect_args=connect_args) -db_session = scoped_session(sessionmaker(autocommit=False, - autoflush=True, - bind=engine, - class_=Session)) - -def get_session(): - with db_session() as session: - return session - -def add_boardgame(session: Session, boardgame: Union[ +def add_boardgame(boardgame: Union[ boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion, boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion, boardgame_classes.WishlistBoardGame, boardgame_classes.WishlistBoardGameExpansion]): - is_boardgame_present = len(session.exec( - select(boardgame.__class__).where(boardgame.__class__.id == boardgame.id) - ).all()) != 0 + with critical_function_lock: + with Session(engine) as session: + is_boardgame_present = len(session.exec( + select(boardgame.__class__).where(boardgame.__class__.id == boardgame.id) + ).all()) != 0 - if not is_boardgame_present: - session.add(boardgame) + if not is_boardgame_present: + session.add(boardgame) - session.commit() + session.commit() -def add_multiple_boardgames(session: Session, boardgame_list: list[Union[ +def add_multiple_boardgames(boardgame_list: list[Union[ boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion, boardgame_classes.OwnedBoardGame, boardgame_classes.OwnedBoardGameExpansion, boardgame_classes.WishlistBoardGame, boardgame_classes.WishlistBoardGameExpansion]]): - for boardgame in boardgame_list: - is_boardgame_present = len(session.exec( - select(boardgame.__class__).where(boardgame.__class__.id == boardgame.id) - ).all()) != 0 + with critical_function_lock: + with Session(engine) as session: + for boardgame in boardgame_list: + is_boardgame_present = len(session.exec( + select(boardgame.__class__).where(boardgame.__class__.id == boardgame.id) + ).all()) != 0 - if not is_boardgame_present: - session.add(boardgame) + if not is_boardgame_present: + session.add(boardgame) - session.commit() + session.commit() -def get_boardgame(session: Session, boardgame_id: int) -> Union[ +def get_boardgame(boardgame_id: int) -> Union[ boardgame_classes.BoardGame, boardgame_classes.BoardGameExpansion]: - statement = select(boardgame_classes.BoardGame).where(boardgame_classes.BoardGame.id == boardgame_id) - - base_boardgames = session.exec(statement).all() + with Session(engine) as session: + statement = select(boardgame_classes.BoardGame).where(boardgame_classes.BoardGame.id == boardgame_id) + + base_boardgames = session.exec(statement).all() - statement = select(boardgame_classes.BoardGameExpansion).where(boardgame_classes.BoardGameExpansion.id == boardgame_id) + statement = select(boardgame_classes.BoardGameExpansion).where(boardgame_classes.BoardGameExpansion.id == boardgame_id) - expansion_boardgames = session.exec(statement).all() + expansion_boardgames = session.exec(statement).all() - returned_boardgames = base_boardgames + expansion_boardgames + returned_boardgames = base_boardgames + expansion_boardgames - if len(returned_boardgames) == 0: - boardgame = None - else: - boardgame = returned_boardgames[0] + if len(returned_boardgames) == 0: + boardgame = None + else: + boardgame = returned_boardgames[0] - return boardgame + return boardgame -def get_multiple_boardgames(session: Session, boardgame_ids: list[int]) -> tuple[Union[ +def get_multiple_boardgames(boardgame_ids: list[int]) -> tuple[Union[ list[boardgame_classes.BoardGame], list[boardgame_classes.BoardGameExpansion]], list[int]]: - statement = select(boardgame_classes.BoardGame).where(boardgame_classes.BoardGame.id.in_(boardgame_ids)) - results = session.exec(statement) + with Session(engine) as session: + statement = select(boardgame_classes.BoardGame).where(boardgame_classes.BoardGame.id.in_(boardgame_ids)) + results = session.exec(statement) - boardgames = results.all() + boardgames = results.all() - statement = select(boardgame_classes.BoardGameExpansion).where(boardgame_classes.BoardGameExpansion.id.in_(boardgame_ids)) - results = session.exec(statement) + statement = select(boardgame_classes.BoardGameExpansion).where(boardgame_classes.BoardGameExpansion.id.in_(boardgame_ids)) + results = session.exec(statement) - expansions = results.all() + expansions = results.all() - boardgames += expansions + boardgames += expansions - missing_boardgame_ids = list(filter(lambda x: x not in [boardgame.id for boardgame in boardgames], boardgame_ids)) + missing_boardgame_ids = list(filter(lambda x: x not in [boardgame.id for boardgame in boardgames], boardgame_ids)) - return boardgames, missing_boardgame_ids + return boardgames, missing_boardgame_ids -def get_all_boardgames(session: Session, boardgame_type: SQLModel) -> Union[ +def get_all_boardgames(boardgame_type: SQLModel) -> Union[ list[boardgame_classes.BoardGame], list[boardgame_classes.BoardGameExpansion], list[boardgame_classes.OwnedBoardGame], list[boardgame_classes.OwnedBoardGameExpansion], list[boardgame_classes.WishlistBoardGame], list[boardgame_classes.WishlistBoardGameExpansion]]: - statement = select(boardgame_type) + with Session(engine) as session: + statement = select(boardgame_type) + + results = session.exec(statement) - results = session.exec(statement) - - boardgame_list = results.all() + boardgame_list = results.all() - return boardgame_list + return boardgame_list -def add_play(session: Session, play: play_classes.Play): +def add_play(play: play_classes.Play): - session.add(play) + with critical_function_lock: - session.commit() - session.refresh(play) - -def add_multiple_plays(session: Session, play_list: list[play_classes.Play]): - - for play in play_list: - is_play_present = len(session.exec(select(play_classes.Play).where(play_classes.Play.id == play.id)).all()) != 0 - - if not is_play_present: + with Session(engine) as session: session.add(play) - session.commit() + session.commit() + session.refresh(play) -def get_plays(session: Session) -> list[play_classes.Play]: - statement = select(play_classes.Play) - results = session.exec(statement) +def add_multiple_plays(play_list: list[play_classes.Play]): - play_list = results.all() + with critical_function_lock: + with Session(engine) as session: + for play in play_list: + is_play_present = len(session.exec(select(play_classes.Play).where(play_classes.Play.id == play.id)).all()) != 0 - return play_list + if not is_play_present: + session.add(play) -def get_players_from_play(session: Session, play_id: int) -> list[play_classes.PlayPlayer]: - statement = select(play_classes.PlayPlayer).where(play_classes.PlayPlayer.play_id == play_id) - results = session.exec(statement) + session.commit() - player_list = results.all() +def get_plays() -> list[play_classes.Play]: + with Session(engine) as session: + statement = select(play_classes.Play) + results = session.exec(statement) - return player_list + play_list = results.all() + + return play_list + +def get_players_from_play(play_id: int) -> list[play_classes.PlayPlayer]: + with Session(engine) as session: + statement = select(play_classes.PlayPlayer).where(play_classes.PlayPlayer.play_id == play_id) + results = session.exec(statement) + + player_list = results.all() + + return player_list def delete_database(): SQLModel.metadata.drop_all(engine) diff --git a/tests/test_main.py b/tests/test_main.py index c1a7402..f6de31e 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -122,7 +122,7 @@ def test_retrieve_game_order_statistic(): response = client.get("/statistics/most_expensive_games") assert response.status_code == 200 - returned_statistic = statistic_classes.GameOrderStatistic.model_validate(response.json()) + returned_statistic = statistic_classes.GamesStatistic.model_validate(response.json()) default_statistic_test(returned_statistic) assert type(returned_statistic.result) == list