"""Dump every tensor in a saved PyTorch checkpoint to a plain-text file.

Usage::

    python <this script> CHECKPOINT [--output PATH]

For each entry of the loaded mapping, the output contains the key on its own
line, then the tensor values (after ``squeeze()``), then a blank line:

* 0-D tensors: a single value.
* 1-D tensors: one comma-separated line.
* 2-D and higher tensors: one comma-separated line per slice along dim 0.
  Any trailing dimensions are flattened in row-major order.
"""
import os  # noqa: F401  (kept from the original file header)
import argparse

import torch
import numpy as np  # noqa: F401  (kept from the original file header)


def _tensor_lines(tensor):
    """Return the text lines that represent one tensor's values."""
    t = tensor.detach().squeeze().cpu()
    if t.dim() == 0:
        return [str(t.item())]
    if t.dim() == 1:
        return [",".join(t.numpy().astype(str))]
    # Collapse everything after dim 0 so that >2-D tensors (e.g. conv weights)
    # become one row per slice instead of crashing in str.join.
    # flatten(1) is a no-op on 2-D tensors and also handles zero-size tensors.
    rows = t.flatten(start_dim=1).numpy().astype(str)
    return [",".join(row) for row in rows]


def dump_state_dict(state_dict, out):
    """Write ``key``, the value lines, and a blank separator for each entry.

    Args:
        state_dict: mapping of name -> tensor, iterated in insertion order.
        out: writable text stream.
    """
    for name, tensor in state_dict.items():
        out.write(name + "\n")
        for line in _tensor_lines(tensor):
            out.write(line + "\n")
        out.write("\n")


def main(argv=None):
    """Parse CLI arguments, load the checkpoint and write the text dump."""
    parser = argparse.ArgumentParser(
        description="Dump a PyTorch checkpoint's tensors to a text file."
    )
    parser.add_argument("stdc", help="checkpoint path passed to torch.load")
    parser.add_argument(
        "--output",
        default="output_stdc.txt",
        help="destination text file (default: output_stdc.txt)",
    )
    args = parser.parse_args(argv)

    # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
    state_dict = torch.load(args.stdc, map_location="cpu")

    # Mode "w" truncates any previous dump, replacing the old `rm` shell call.
    with open(args.output, "w", encoding="utf-8") as out:
        dump_state_dict(state_dict, out)


if __name__ == "__main__":
    main()