Coverage for aisdb/webdata/load_raster.py: 100%
40 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-30 04:22 +0000
1import os
3import numpy as np
4from PIL import Image
6from aisdb.aisdb import binarysearch_vector
# Raise Pillow's decompression-bomb pixel limit so very large rasters can be
# opened without DecompressionBomb warnings/errors.
Image.MAX_IMAGE_PIXELS = 650_000_000
class _RasterFile_generic:
    """Shared context-manager and track-merging behaviour for raster readers.

    Subclasses are expected to assign ``self.img`` (an object with a
    ``close()`` method) and to provide ``_track_coordinate_values(track)``,
    which this base class calls but does not define.
    """

    def __enter__(self):
        # the context is only usable once an image has been attached
        assert hasattr(self, 'img')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        ''' close raster files upon exit from context '''
        self.img.close()

    def merge_tracks(self, tracks, new_track_key: str):
        """Lazily attach raster values to each track under ``new_track_key``.

        For every track, ``new_track_key`` is added to a fresh copy of the
        track's 'dynamic' key set, then the sampled values are stored under
        ``new_track_key``. Tracks are mutated in place and yielded one by one.

        Args:
            tracks: iterable of track dicts, each with a 'dynamic' collection.
            new_track_key: name of the new per-position field.

        Yields:
            Each input track after it has been updated.
        """
        for trk in tracks:
            dynamic_keys = set(trk['dynamic'])
            dynamic_keys.add(new_track_key)
            trk['dynamic'] = dynamic_keys
            trk[new_track_key] = self._track_coordinate_values(trk)
            yield trk
class RasterFile(_RasterFile_generic):
    """Sample per-pixel values from a georeferenced TIFF raster.

    On construction the image is opened and longitude/latitude axes for the
    pixel grid are rebuilt from its GeoTIFF metadata tags. Track positions
    are then mapped to pixel indices with ``binarysearch_vector`` and read
    with ``getpixel``. Use as a context manager so the image is closed on
    exit.
    """

    def _get_img_grids(self, im):
        """Build the longitude and latitude axis arrays for image ``im``.

        Returns:
            ``(lon, lat)`` numpy arrays with exactly ``im.size[0]`` and
            ``im.size[1]`` entries respectively (one value per pixel column
            and per pixel row).

        Raises:
            ValueError: if neither tag 33922 (ModelTiepointTag) nor tag 34264
                (ModelTransformationTag) is present.
        """
        width, height = im.size
        # np.arange with a float step can emit one element too many through
        # rounding error, which would shift every coordinate by one pixel.
        # Scaling an integer range guarantees exactly one value per pixel.
        steps_x = np.arange(1, width + 1)
        steps_y = np.arange(1, height + 1)

        if 33922 in im.tag.tagdata:
            # GDAL tags
            _, _, _, x, y, _ = im.tag_v2[33922]  # ModelTiepointTag
            dx, dy, _ = im.tag_v2[33550]  # ModelPixelScaleTag
            lat = (y + dy * steps_y)[::-1] - 90
            # NOTE(review): if the shifted axis still exceeds 91 it is shifted
            # by a further -90; presumably this recentres rasters whose
            # tiepoint origin is offset — confirm against source data.
            if np.any(lat > 91):
                lat -= 90

        elif 34264 in im.tag.tagdata:  # pragma: no cover
            # NASA JPL tags
            dx, _, _, x, _, dy, _, y, _, _, _, _, _, _, _, _ = im.tag_v2[
                34264]  # ModelTransformationTag
            lat = y + dy * steps_y

        else:
            raise ValueError('error: unknown metadata tag encoding')

        lon = x + dx * steps_x

        return lon, lat

    def __init__(self, imgpath):
        """Open the raster at ``imgpath`` and build its coordinate grids.

        Raises:
            AssertionError: (when assertions are enabled) if ``imgpath`` is
                not an existing file.
            ValueError: if the raster's georeferencing tags are not
                recognised; the image is closed before the error propagates.
        """
        self.imgpath = imgpath
        assert not hasattr(self, 'img')
        assert os.path.isfile(
            self.imgpath), f'raster file {self.imgpath} not found!'
        self.img = Image.open(self.imgpath)
        try:
            self.xy = self._get_img_grids(self.img)
        except BaseException:
            # __exit__ never runs when the constructor fails, so release the
            # image file handle here instead of leaking it.
            self.img.close()
            raise

    def _get_coordinate_values(self, track, rng=None):
        """Read raster pixel values at a track's positions.

        Args:
            track: mapping with 'time', 'lon' and 'lat' arrays.
            rng: indices of the positions to sample; defaults to every
                position (``range(len(track['time']))``).

        Returns:
            numpy array with one ``getpixel`` result per sampled position.
        """
        if rng is None:
            rng = range(len(track['time']))
        # grid index of each position: lon -> pixel column, lat -> pixel row
        idx_lons = np.array(binarysearch_vector(self.xy[0], track['lon'][rng]))
        idx_lats = np.array(binarysearch_vector(self.xy[1], track['lat'][rng]))
        return np.array(list(map(
            self.img.getpixel,
            zip(idx_lons, idx_lats),
        )))

    def _track_coordinate_values(self, track, *, rng: range = None):
        """Return raster values for ``track``; used by ``merge_tracks``."""
        return self._get_coordinate_values(track, rng=rng)