import logging
from pathlib import Path
from shutil import copy

from localconfig import LocalConfig
from torch import multiprocessing as mp

from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear
from decentralizepy.node.Node import Node


def read_ini(file_path):
    """
    Load a configuration file with ``LocalConfig`` and log its contents.

    Each section name and each ``(key, value)`` pair in it is written to
    this module's logger at DEBUG level, so the dump is only visible when
    logging is configured to show DEBUG messages.

    Parameters
    ----------
    file_path : str
        Path of the configuration file; passed unchanged to ``LocalConfig``.

    Returns
    -------
    LocalConfig
        The parsed configuration object.

    """
    # A named logger is used instead of the module-level ``logging.debug``
    # helpers, which would implicitly call ``basicConfig`` on the root logger.
    logger = logging.getLogger(__name__)
    config = LocalConfig(file_path)
    # Iterating the config yields section names; ``items(section)`` yields
    # that section's key/value pairs.
    for section in config:
        logger.debug("Section: %s", section)
        for key, value in config.items(section):
            logger.debug("%s = %r", key, value)
    return config


if __name__ == "__main__":
    # Entry point: parse CLI arguments, prepare the log directory, load the
    # config and graph, then spawn one ``Node`` process per local process slot.
    args = utils.get_args()

    # Create the output directory first; the copies below write into it.
    Path(args.log_dir).mkdir(parents=True, exist_ok=True)

    # Translate the ``args.log_level`` name into a ``logging`` level constant.
    log_level = {
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }

    # Flatten the config into plain nested dicts ({section: {key: value}})
    # before handing it to the worker processes.
    # NOTE(review): presumably done so the config pickles cleanly for the
    # spawned processes; confirm.
    config = read_ini(args.config_file)
    my_config = {section: dict(config.items(section)) for section in config}

    # Keep copies of this run's inputs (config, graph, CLI args) in the log dir.
    copy(args.config_file, args.log_dir)
    copy(args.graph_file, args.log_dir)
    utils.write_args(args, args.log_dir)

    g = Graph()
    g.read_graph_from_file(args.graph_file, args.graph_type)

    n_machines = args.machines
    procs_per_machine = args.procs_per_machine
    # NOTE(review): Linear presumably maps (machine id, local rank) to a global
    # node id across n_machines * procs_per_machine processes; confirm.
    mapping = Linear(n_machines, procs_per_machine)
    m_id = args.machine_id

    # torch.multiprocessing.spawn calls ``fn(i, *args)`` for i in
    # range(nprocs), so each worker is built as Node(i, m_id, mapping, ...).
    mp.spawn(
        fn=Node,
        nprocs=procs_per_machine,
        args=(
            m_id,
            mapping,
            g,
            my_config,
            args.iterations,
            args.log_dir,
            log_level[args.log_level],
            args.test_after,
        ),
    )