I am running a simulation and I need to record some small numpy arrays every cycle. My current solution is to load, write, and then save, as follows:
```python
existing_data = np.load("existing_record.npy")
...
```
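(For illustration only, a minimal hypothetical sketch of such a load-append-save cycle; the array names and shapes below are made up, not taken from the original code:)

```python
import numpy as np

# Hypothetical load-append-save cycle: each iteration rereads and rewrites
# the entire file, so the per-cycle cost grows with the amount of data logged.
new_array = np.zeros((3, 2))                    # stand-in for one cycle's data
existing_data = np.load("existing_record.npy")  # load everything written so far
updated = np.concatenate((existing_data, new_array[None, ...]), axis=0)
np.save("existing_record.npy", updated)         # rewrite the whole file
```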
I think one solution would be to use a memory-mapped file via numpy.memmap. The code can be found below, and the documentation contains important information for understanding it.
```python
import numpy as np
from os.path import getsize
from time import time

filename = "data.bin"

# Datatype used for memmap
dtype = np.int32

# Create memmap for the first time (w+). Arbitrary shape.
# Probably good to try and guess the correct size.
mm = np.memmap(filename, dtype=dtype, mode='w+', shape=(1, ))
print("File has {} bytes".format(getsize(filename)))

N = 20
num_data_per_loop = 10**7

# Main loop to append data
for i in range(N):
    # will extend the file because mode='r+'
    starttime = time()
    mm = np.memmap(filename, dtype=dtype, mode='r+',
                   offset=np.dtype(dtype).itemsize*num_data_per_loop*i,
                   shape=(num_data_per_loop, ))
    mm[:] = np.arange(start=num_data_per_loop*i, stop=num_data_per_loop*(i+1))
    mm.flush()
    endtime = time()
    print("{:3d}/{:3d} ({:6.4f} sec): File has {} bytes".format(
        i, N, endtime-starttime, getsize(filename)))

A = np.array(np.memmap(filename, dtype=dtype, mode='r'))
if np.array_equal(A, np.arange(num_data_per_loop*N, dtype=dtype)):
    print("Correct")
```
The output I get is:
```
File has 4 bytes
  0/ 20 (0.2167 sec): File has 40000000 bytes
  1/ 20 (0.2200 sec): File has 80000000 bytes
  2/ 20 (0.2131 sec): File has 120000000 bytes
  3/ 20 (0.2180 sec): File has 160000000 bytes
  4/ 20 (0.2215 sec): File has 200000000 bytes
  5/ 20 (0.2141 sec): File has 240000000 bytes
  6/ 20 (0.2187 sec): File has 280000000 bytes
  7/ 20 (0.2138 sec): File has 320000000 bytes
  8/ 20 (0.2137 sec): File has 360000000 bytes
  9/ 20 (0.2227 sec): File has 400000000 bytes
 10/ 20 (0.2168 sec): File has 440000000 bytes
 11/ 20 (0.2141 sec): File has 480000000 bytes
 12/ 20 (0.2150 sec): File has 520000000 bytes
 13/ 20 (0.2144 sec): File has 560000000 bytes
 14/ 20 (0.2190 sec): File has 600000000 bytes
 15/ 20 (0.2186 sec): File has 640000000 bytes
 16/ 20 (0.2210 sec): File has 680000000 bytes
 17/ 20 (0.2146 sec): File has 720000000 bytes
 18/ 20 (0.2178 sec): File has 760000000 bytes
 19/ 20 (0.2182 sec): File has 800000000 bytes
Correct
```
Thanks to the offsets used for the memmaps, the time per iteration stays roughly constant. The amount of RAM required is also constant (apart from loading the whole memmap at the end for the check).
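As a side note, the same offset arithmetic lets you read back a single chunk without mapping anything else into memory. A small sketch reusing the variables from the script above (the chunk index 5 is arbitrary):

```python
# Read back only the 6th chunk (index 5), reusing filename, dtype and
# num_data_per_loop from the script above; the rest of the file stays on disk.
chunk = np.memmap(filename, dtype=dtype, mode='r',
                  offset=np.dtype(dtype).itemsize * num_data_per_loop * 5,
                  shape=(num_data_per_loop, ))
print(chunk[:5])  # first few values of that chunk
```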
I hope this solves your performance problem.

Kind regards,
Lukas
Edit 1: It seems the original poster has solved their own problem. I am leaving this answer up as an alternative.
I found a well-working solution using the h5py library. Performance is much better because no data is read back, and I reduced the number of numpy array append operations. A short example:
```python
import h5py

# Open (or create) the log file in append mode; the third axis is declared
# unlimited via maxshape=None, so the dataset can grow with the simulation.
with h5py.File("logfile_name", "a") as f:
    ds = f.create_dataset("weights", shape=(3, 2, 100000),
                          maxshape=(3, 2, None))
    # cycle_num and weight_matrix come from the simulation loop
    ds[:, :, cycle_num] = weight_matrix
```
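One detail the snippet glosses over: because the dataset was created with maxshape=(3, 2, None), the third axis can be grown on demand with Dataset.resize once cycle_num outgrows the initial 100000 slots. A minimal sketch, assuming ds, cycle_num and weight_matrix as above:

```python
# Grow the unlimited axis before writing past the current end.
if cycle_num >= ds.shape[2]:
    ds.resize(cycle_num + 1, axis=2)
ds[:, :, cycle_num] = weight_matrix
```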
I am not sure whether numpy-style slicing means the matrix gets copied, but there is a write_direct(source, source_sel=None, dest_sel=None) method to avoid that, which might be useful for larger matrices.
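For illustration, a sketch of how write_direct could replace the slice assignment above; selections are built with numpy.s_, and ds, cycle_num and weight_matrix are assumed to exist as before:

```python
import numpy as np

# Write weight_matrix directly into this cycle's slab, avoiding an
# intermediate copy; dest_sel selects the same region as ds[:, :, cycle_num].
ds.write_direct(weight_matrix, dest_sel=np.s_[:, :, cycle_num])
```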