diff --git a/src/modules/db_connection.py b/src/modules/db_connection.py index 3048dcb..d07690a 100644 --- a/src/modules/db_connection.py +++ b/src/modules/db_connection.py @@ -66,19 +66,15 @@ def get_boardgame(session: Session, boardgame_id: int) -> Union[ def get_multiple_boardgames(session: Session, boardgame_type: SQLModel, boardgame_ids: list[int]) -> tuple[Union[ list[boardgame_classes.BoardGame], list[boardgame_classes.BoardGameExpansion]], list[int]]: - missing_boardgame_ids = [] - to_return_boardgames = [] + statement = select(boardgame_type).where(boardgame_type.id.in_(boardgame_ids)) + results = session.exec(statement) - for boardgame_id in boardgame_ids: - statement = select(boardgame_type).where(boardgame_type.id == boardgame_id) - results = session.exec(statement).all() - if len(results) == 0: - missing_boardgame_ids.append(boardgame_id) - else: - to_return_boardgames.append(results[0]) + boardgames = results.all() - return to_return_boardgames, missing_boardgame_ids + missing_boardgame_ids = list(filter(lambda x: x not in [boardgame.id for boardgame in boardgames], boardgame_ids)) + + return boardgames, missing_boardgame_ids def get_all_boardgames(session: Session, boardgame_type: SQLModel) -> Union[ list[boardgame_classes.BoardGame], list[boardgame_classes.BoardGameExpansion],