diff --git a/src/get_images_sdss/get_images_sdss.py b/src/get_images_sdss/get_images_sdss.py index 71db6e53601e45a598dc77120f1572ecee035071..0ffd8118abf203324402068bf0697b39e1953eab 100644 --- a/src/get_images_sdss/get_images_sdss.py +++ b/src/get_images_sdss/get_images_sdss.py @@ -37,8 +37,12 @@ def get_args(): help="Path to save image files.\n") parser.add_argument('-i', '--inputfile', type=dir_file, default='galaxies.csv', help="Galaxy database file in *.csv format.\n") + parser.add_argument('-o', '--outputfile', type=str, default=None, + help="""Output galaxy database file in *.csv format containing + the path to the downloaded images. If not specified, the + paths will be written in the input CSV file.""") parser.add_argument('-s', '--size', type=int, default=512, - help="Size of the downloaded images.\n") + help="Size of the edges (in pixels) of the downloaded images.\n") return parser.parse_args() @@ -46,10 +50,11 @@ def main(): # Params args = get_args() path = Path(args.path) - file = Path(args.inputfile) + inputfile = Path(args.inputfile) size = args.size + outputfile = args.outputfile # Read catalogue - galaxies = pd.read_csv(file) + galaxies = pd.read_csv(inputfile) # Download images n = len(galaxies) @@ -63,11 +68,15 @@ def main(): fname = f"{path}/img_{group}_{id}.jpeg" download_sdss(fname, row['ra'], row['dec'], size) + galaxies.loc[index, 'filename'] = f"img_{group}_{id}.jpeg" if index % 10 == 0: print(f"Remaining {n-index}") except KeyError as e: print(f"KeyError: {e}") + + file = Path(outputfile) if outputfile else inputfile + galaxies.to_csv(file, index=None) print("Done!") if __name__ == '__main__':