Fix CIEDE script for high bit depth
[daala.git] / tools / dump_ciede2000.py
1 #!/usr/bin/env python
2
3 from collections import deque
4 import sys
5 import numpy as np
6 from skimage import color
7 import y4m
8
9 # Assuming BT.709
10 yuv2rgb = np.array([
11     [1., 0., 1.28033], [1., -0.21482, -0.38059], [1., 2.12798, 0.]
12 ])
13
14 # Simple box filter
15 box2 = np.ones((2, 2))
16
17
18 def usage():
19     print("Usage: %s <video1> <video2>\n"
20             "    <video1> and <video2> must be YUV4MPEG files.\n\n" %  __file__);
21
22 def decode_y4m_buffer(frame):
23     W, H = frame.headers['W'], frame.headers['H']
24     Wdiv2, Hdiv2 = W // 2, H // 2
25     C, buf = frame.headers['C'], frame.buffer
26     A, Adiv2, div2 = W * H, Hdiv2 * Wdiv2, (Hdiv2, Wdiv2)
27     dtype, scale = 'uint8', 1.
28     if C.endswith('p10'):
29         dtype, scale, A, Adiv2 = 'uint16', 4., A * 2, Adiv2 * 2
30     Y = (np.ndarray((H, W), dtype, buf) - 16. * scale) / (219. * scale)
31     if C.startswith('420'):
32         Cb = (np.ndarray(div2, dtype, buf, A) - 128. * scale) / (224. * scale)
33         Cr = (np.ndarray(div2, dtype, buf, A + Adiv2) - 128. * scale) / (224. * scale)
34         YCbCr444 = np.dstack((Y, np.kron(Cb, box2), np.kron(Cr, box2)))
35     else:
36         Cb = (np.ndarray((H, W), dtype, buf, A) - 128. * scale) / (224. * scale)
37         Cr = (np.ndarray((H, W), dtype, buf, A * 2) - 128. * scale) / (224. * scale)
38         YCbCr444 = np.dstack((Y, Cb, Cr))
39     return np.dot(YCbCr444, yuv2rgb.T)
40
41
42 scores = []
43
44
45 def process_pair(ref, recons):
46     ref_lab = color.rgb2lab(decode_y4m_buffer(ref))
47     recons_lab = color.rgb2lab(decode_y4m_buffer(recons))
48     # "Color Image Quality Assessment Based on CIEDE2000"
49     # Yang Yang, Jun Ming and Nenghai Yu, 2012
50     # http://dx.doi.org/10.1155/2012/273723
51     dE = color.deltaE_ciede2000(ref_lab, recons_lab, kL=0.65, kC=1.0, kH=4.0)
52     scores.append(45. - 20. * np.log10(dE.mean()))
53     print('%08d: %2.4f' % (ref.count, scores[-1]))
54
55
56 ref_frames = deque()
57 recons_frames = deque()
58
59
60 def process_ref(frame):
61     ref_frames.append(frame)
62     if recons_frames:
63         process_pair(ref_frames.popleft(), recons_frames.popleft())
64
65
66 def process_recons(frame):
67     recons_frames.append(frame)
68     if ref_frames:
69         process_pair(ref_frames.popleft(), recons_frames.popleft())
70
71 class Reader(y4m.Reader):
72     def _frame_size(self):
73         area = self._stream_headers['W'] * self._stream_headers['H']
74         C = self._stream_headers['C']
75         if C.startswith('420'):
76             pixels = area * 3 // 2
77         elif C.startswith('444'):
78             pixels = area * 3
79         else:
80             raise Exception('Unknown pixel format: %s' % C)
81         if self._stream_headers['C'].endswith('p10'):
82             return 2 * pixels
83         return pixels
84
85 def main(args):
86     if len(args) != 3:
87         usage()
88         sys.exit(0)
89
90     OPENING = 'Opening %s...'
91     BLOCK_SIZE = 4 * 1024 * 1024
92     ref_parser = Reader(process_ref)
93     recons_parser = Reader(process_recons)
94     print(OPENING % args[1])
95     with open(args[1], 'r') as ref:
96         print(OPENING % args[2])
97         with open(args[2], 'r') as recons:
98             try:
99                 ref_buf, recons_buf = ref.buffer, recons.buffer
100             except:
101                 ref_buf, recons_buf = ref, recons
102             while True:
103                 if not ref_frames:
104                     data = ref_buf.read(BLOCK_SIZE)
105                     if not data: break
106                     ref_parser.decode(data)
107                 if not recons_frames:
108                     data = recons_buf.read(BLOCK_SIZE)
109                     if not data: break
110                     recons_parser.decode(data)
111     print('Total: %2.4f' % np.array(scores).mean())
112
113
114 if __name__ == '__main__':
115     main(sys.argv)