From caf83485c7201fe5fc273fc46332690355a688be Mon Sep 17 00:00:00 2001
From: Barbagus42
Date: Tue, 7 Nov 2023 09:32:08 +0100
Subject: [PATCH] Fix import/typing in update.py

---
 data/update.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/data/update.py b/data/update.py
index b7a3ea6..c2e1b3b 100644
--- a/data/update.py
+++ b/data/update.py
@@ -6,13 +6,10 @@ import urllib.request as request
 import os.path
 import math
 import contextlib as ctx
+import http.client as http
 
-from typing import Callable, Any, Optional, TYPE_CHECKING
+from typing import Callable, Any, Optional
 
-if TYPE_CHECKING:
-    from http.client import HTTPResponse
-    from zipfile import ZipFile
-    from sqlite3 import Connection as DBconnection
 
 # GTFS reference:
 # https://gtfs.org/schedule/reference/
@@ -151,7 +148,7 @@ def _get_file_names(etag: str):
     )
 
 
-def _fetch_dataset(response: HTTPResponse, dataset_file: str):
+def _fetch_dataset(response: http.HTTPResponse, dataset_file: str):
     print("Fetching dataset...")
     content_length = int(response.getheader("Content-Length"))
     with open(dataset_file, "wb") as zip_out:
@@ -164,7 +161,7 @@ def _fetch_dataset(response: HTTPResponse, dataset_file: str):
             print(f"Fetched: {zip_out.tell()}/{content_length} {progress}%")
 
 
-def _check_dataset_files(zip_in: ZipFile):
+def _check_dataset_files(zip_in: zipfile.ZipFile):
     csv_files = list(sorted(zip_in.namelist()))
     expected = list(sorted(csv_file for csv_file, _ in CSV_FIELDS))
 
@@ -179,13 +176,13 @@ def _check_csv_headers(csv_headers: list[str], fields: list[tuple[str, Any]]):
     assert all(a == b for a, b in zip(csv_headers, expected)), csv_headers
 
 
-def _create_schema(db: DBconnection, schema_file: str):
+def _create_schema(db: sqlite3.Connection, schema_file: str):
     with db:
         with open(schema_file) as sql_in:
             db.executescript(sql_in.read())
 
 
-def _load_data(zip_in: ZipFile, db: DBconnection):
+def _load_data(zip_in: zipfile.ZipFile, db: sqlite3.Connection):
     for csv_file, fields in CSV_FIELDS:
         if fields is None:
             continue
@@ -236,7 +233,7 @@ def _load_database(dataset_file: str, database_tempfile: str, schema_file: str):
 def main():
     print("Checking dataset")
 
-    response: HTTPResponse = request.urlopen(GTFS_URL)
+    response: http.HTTPResponse = request.urlopen(GTFS_URL)
     if response.status != 200:
         raise RuntimeError("Could not fetch the dataset")