Coverage for aisdb/webdata/load_raster.py: 100%

40 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-30 04:22 +0000

1import os 

2 

3import numpy as np 

4from PIL import Image 

5 

6from aisdb.aisdb import binarysearch_vector 

7 

# Raise Pillow's decompression-bomb pixel limit so the very large rasters
# read by this module open without DecompressionBomb warnings/errors.
Image.MAX_IMAGE_PIXELS = 650_000_000

9 

10 

class _RasterFile_generic():
    ''' shared context-manager and track-merging logic for raster readers

    Subclasses must attach an ``img`` attribute (anything with a ``close()``
    method) before the instance is used in a ``with`` block, and must
    implement ``_track_coordinate_values(track)``, which ``merge_tracks``
    calls once per track.
    '''

    def __enter__(self):
        # entering the context is only meaningful once an image is attached
        assert hasattr(self, 'img')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        ''' close the raster image when leaving the ``with`` block

        Returns None, so any exception raised inside the block propagates.
        '''
        self.img.close()

    def merge_tracks(self, tracks, new_track_key: str):
        ''' lazily attach raster values to each track

        For every track, ``new_track_key`` is added to the track's
        ``'dynamic'`` keys (which are converted to a ``set``). The values
        returned by ``_track_coordinate_values`` are stored under
        ``new_track_key``. Tracks are modified in place.

        Args:
            tracks: iterable of track dicts, each with a ``'dynamic'``
                collection of key names
            new_track_key: key under which the sampled values are stored

        Yields:
            the same track dicts, after updating
        '''
        for trk in tracks:
            dynamic_keys = set(trk['dynamic'])
            dynamic_keys.add(new_track_key)
            trk['dynamic'] = dynamic_keys
            trk[new_track_key] = self._track_coordinate_values(trk)
            yield trk

27 

28 

class RasterFile(_RasterFile_generic):
    ''' georeferenced raster image that can be sampled along vessel tracks

    On construction the image at ``imgpath`` is opened with PIL. Its
    longitude/latitude pixel grids are computed from the TIFF georeferencing
    tags. Use the instance as a context manager so the image is closed
    afterwards.

    Attributes:
        imgpath: path the raster was opened from
        img: the opened PIL image
        xy: ``(lon, lat)`` tuple of 1-D numpy arrays with one value per
            pixel column and per pixel row, respectively
    '''

    @staticmethod
    def _axis_grid(origin, step, count):
        ''' return the ``count`` values ``origin + step * k`` for k = 1..count

        ``np.arange`` with a non-integer step derives its length from
        ``ceil((stop - start) / step)``. Floating-point rounding can push
        that to ``count + 1``. Scaling an integer index range always yields
        exactly ``count`` values.
        '''
        return origin + step * np.arange(1, count + 1)

    def _get_img_grids(self, im):
        ''' compute longitude and latitude pixel grids from TIFF geotags

        Args:
            im: an opened PIL TIFF image exposing ``tag`` and ``tag_v2``

        Returns:
            ``(lon, lat)``: 1-D numpy arrays with exactly ``im.size[0]`` and
            ``im.size[1]`` values respectively

        Raises:
            ValueError: if neither a ModelTiepointTag (33922) nor a
                ModelTransformationTag (34264) is present
        '''
        width, height = im.size
        tagged = im.tag.tagdata

        if 33922 in tagged:
            # GDAL tags
            _, _, _, x, y, _ = im.tag_v2[33922]  # ModelTiepointTag
            dx, dy, _ = im.tag_v2[33550]  # ModelPixelScaleTag
            # reverse so the latitude grid runs opposite to the +dy step
            lat = self._axis_grid(y, dy, height)[::-1] - 90
            # NOTE(review): the -90 offsets appear to normalise rasters whose
            # latitude origin lies outside [-90, 90]; the intended convention
            # is undocumented -- confirm against real data.
            if np.any(lat > 91):
                lat -= 90

        elif 34264 in tagged:  # pragma: no cover
            # NASA JPL tags: 16-value ModelTransformationTag
            (dx, _, _, x,
             _, dy, _, y,
             _, _, _, _,
             _, _, _, _) = im.tag_v2[34264]
            lat = self._axis_grid(y, dy, height)

        else:
            raise ValueError('error: unknown metadata tag encoding')

        lon = self._axis_grid(x, dx, width)

        return lon, lat

    def __init__(self, imgpath):
        ''' open ``imgpath`` and compute its coordinate grids

        Args:
            imgpath: filesystem path to the raster image

        Raises:
            AssertionError: if ``imgpath`` is not an existing file
            ValueError: if the image has no recognised georeferencing tag
        '''
        self.imgpath = imgpath
        assert not hasattr(self, 'img')
        # NOTE(review): assert-based validation is skipped under ``python -O``
        assert os.path.isfile(
            self.imgpath), f'raster file {self.imgpath} not found!'
        self.img = Image.open(self.imgpath)
        try:
            self.xy = self._get_img_grids(self.img)
        except BaseException:
            # construction failed, so no ``with`` block will ever reach
            # __exit__; release the file handle here instead of leaking it
            self.img.close()
            raise

    def _get_coordinate_values(self, track, rng=None):
        ''' sample raster pixel values at each track position

        Args:
            track: mapping with ``'time'``, ``'lon'`` and ``'lat'`` entries.
                ``'lon'``/``'lat'`` are indexed with a ``range``, so they are
                presumably numpy arrays (TODO confirm).
            rng: positions within the track to sample; defaults to all
                ``len(track['time'])`` positions

        Returns:
            numpy array of ``Image.getpixel`` results, one per position
        '''
        if rng is None:
            rng = range(len(track['time']))
        # NOTE(review): assumes binarysearch_vector returns in-bounds column /
        # row indices into the grids; off-raster points may make getpixel
        # raise -- confirm.
        idx_lons = np.array(binarysearch_vector(self.xy[0], track['lon'][rng]))
        idx_lats = np.array(binarysearch_vector(self.xy[1], track['lat'][rng]))
        return np.array(list(map(
            self.img.getpixel,
            zip(idx_lons, idx_lats),
        )))

    def _track_coordinate_values(self, track, *, rng: range = None):
        ''' hook used by ``merge_tracks``; see ``_get_coordinate_values`` '''
        return self._get_coordinate_values(track, rng=rng)