summaryrefslogtreecommitdiffstats
path: root/mlplib/idxfile2np.py
blob: 37f27b463b5e1dbffc9665dfaf4c690d14a3401f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
# -*- coding: utf-8 -*-
#
# Copyright 2021 Michael Büsch <m@bues.ch>
#
# Licensed under the Apache License version 2.0
# or the MIT license, at your option.
# SPDX-License-Identifier: Apache-2.0 OR MIT
#
"""

import gzip
import math
import struct

import numpy as np

# IDX element type codes, stored in the third byte of the IDX header.
IDX_DTYPE_U8    = 0x08
IDX_DTYPE_S8    = 0x09
IDX_DTYPE_S16   = 0x0B
IDX_DTYPE_S32   = 0x0C
IDX_DTYPE_F32   = 0x0D
IDX_DTYPE_F64   = 0x0E

# Maps an IDX type code to the numpy scalar type used for decoding.
# The big endian byte order is applied separately in idxdata2np().
idxdtype_to_npdtype = {
    IDX_DTYPE_U8    : np.uint8,
    IDX_DTYPE_S8    : np.int8,
    IDX_DTYPE_S16   : np.int16,
    IDX_DTYPE_S32   : np.int32,
    IDX_DTYPE_F32   : np.float32,
    IDX_DTYPE_F64   : np.float64,
}

class IdxException(Exception):
    """Raised when IDX data is malformed, truncated or of unknown type."""

def idxdata2np(data):
    """Decode an in-memory IDX blob into a numpy array.

    Layout: a 4-byte big endian header (two zero bytes, the type code byte,
    the dimension count byte), then one big endian uint32 size per
    dimension, then the element data in row-major order.

    :param data: The raw IDX bytes (bytes, bytearray or another buffer).
    :returns: A C-ordered array with the decoded shape and a big endian
              dtype. It is created by np.frombuffer, i.e. it shares memory
              with ``data`` (and is read-only if ``data`` is immutable).
              Bytes following the data section are ignored.
    :raises IdxException: If the header, the dimensions or the data
              section are invalid or truncated.
    """
    if len(data) < 4:
        raise IdxException("IDX data: No header.")
    head, = struct.unpack_from(">I", data, 0)
    if head & 0xFFFF0000:
        # The two leading header bytes must be zero.
        raise IdxException("IDX data: Invalid header.")
    typecode = (head >> 8) & 0xFF
    dtype = idxdtype_to_npdtype.get(typecode, None)
    if dtype is None:
        raise IdxException("IDX data: Invalid type code.")
    dtype = np.dtype(dtype).newbyteorder(">") # big endian
    ndims = head & 0xFF
    if ndims == 0:
        raise IdxException("IDX data: Invalid ndims.")
    offset = 4 + (4 * ndims)
    if len(data) < offset:
        raise IdxException("IDX data: Invalid dims.")
    # The dimension sizes are consecutive big endian uint32 values.
    shape = struct.unpack_from(">%dI" % ndims, data, 4)
    # math.prod works on Python ints and cannot overflow. np.prod uses a
    # fixed-width integer that may wrap around for huge dimensions and let
    # truncated data slip through the length check below.
    count = math.prod(shape)
    if len(data) < offset + (dtype.itemsize * count):
        raise IdxException("IDX data: Invalid data section.")
    # An explicit count makes frombuffer ignore any trailing bytes instead
    # of failing (or producing too many elements for the reshape).
    array = np.frombuffer(data, dtype=dtype, count=count, offset=offset)
    return array.reshape(shape, order="C")

def _read_idx_bytes(filename):
    """Return the contents of *filename*, gunzipped if it is gzip data."""
    try:
        with gzip.open(filename, "rb") as compressed:
            return compressed.read()
    except gzip.BadGzipFile:
        # Not gzip data: treat it as an uncompressed IDX file.
        with open(filename, "rb") as plain:
            return plain.read()

def idxfile2np(filename):
    """Load an IDX file (plain or gzip-compressed) as a numpy array.

    :param filename: Path of the file to read.
    :returns: The array decoded by idxdata2np().
    """
    return idxdata2np(_read_idx_bytes(filename))

if __name__ == "__main__":
    # Command line use: decode the IDX file given as the only argument
    # and print the resulting array.
    import sys
    if len(sys.argv) != 2:
        # Previously a missing argument crashed with an IndexError.
        print("Usage: idxfile2np.py IDXFILE", file=sys.stderr)
        sys.exit(1)
    try:
        print(idxfile2np(sys.argv[1]))
    except (IdxException, OSError, EOFError) as e:
        # Report decode/IO failures (including a truncated gzip stream)
        # as a short message instead of a traceback.
        print("Error: %s" % e, file=sys.stderr)
        sys.exit(1)

# vim: ts=4 sw=4 expandtab
bues.ch cgit interface