diff --git a/eval/step_configs/config_femnist_100.ini b/eval/step_configs/config_femnist_100.ini
index aafe7d6fcc0eab9a4a97162daae35f23177871ff..becf623fe8e80f9ce8eeb267b9bb79685a76684e 100644
--- a/eval/step_configs/config_femnist_100.ini
+++ b/eval/step_configs/config_femnist_100.ini
@@ -26,7 +26,7 @@ loss_class = CrossEntropyLoss
 [COMMUNICATION]
 comm_package = decentralizepy.communication.TCP
 comm_class = TCP
-addresses_filepath = ip_addr_7Machines.json
+addresses_filepath = ip_addr_6Machines.json
 
 [SHARING]
 sharing_package = decentralizepy.sharing.Sharing
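The `addresses_filepath` entry points at a JSON file mapping each machine ID to the address where its processes can be reached. A minimal sketch of how such a file could be generated for six machines (the machine_id -> IP schema and the IPs below are assumptions for illustration, not taken from the repository):

```python
# Hypothetical helper that writes an addresses file for six machines.
# The machine_id -> IP schema is an assumption about what the TCP layer expects.
import json

addresses = {
    str(machine_id): ip
    for machine_id, ip in enumerate(
        ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.5", "10.0.0.6"]
    )
}

with open("ip_addr_6Machines.json", "w") as f:
    json.dump(addresses, f, indent=4)
```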
diff --git a/setup.cfg b/setup.cfg
index fa12457f1ab0415570f8bf5e2558a0dd6183c275..3faa1f36fc490a44ab218e6ce2c38aa78c9b9016 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -40,7 +40,6 @@ install_requires =
         zmq
         jsonlines
         pillow
-        pickle
         smallworld
         localconfig
 include_package_data = True
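For context on the dependency list: pickle is part of the CPython standard library rather than a PyPI distribution, so it imports without any pip-installed package. A quick check:

```python
# pickle ships with the standard library; no entry in install_requires is needed.
import pickle

print(pickle.HIGHEST_PROTOCOL)  # available in any CPython environment
```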
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 173a54f7d3143c2a728b772bb722ea13570c9c61..41d061b6e60fece2f2927a09d1eb1f4744300907 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -24,7 +24,7 @@ class Node:
         plt.title(title)
         plt.savefig(filename)
 
-    def __init__(
+    def instantiate(
         self,
         rank: int,
         machine_id: int,
@@ -38,7 +38,7 @@ class Node:
         *args
     ):
         """
-        Constructor
+        Construct all the member objects (dataset, model, optimizer, trainer, sharing) from the config
         Parameters
         ----------
         rank : int
@@ -95,30 +95,32 @@ class Node:
 
         dataset_configs = config["DATASET"]
         dataset_module = importlib.import_module(dataset_configs["dataset_package"])
-        dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
-        dataset_params = utils.remove_keys(
+        self.dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
+        self.dataset_params = utils.remove_keys(
             dataset_configs, ["dataset_package", "dataset_class", "model_class"]
         )
-        self.dataset = dataset_class(
-            self.rank, self.machine_id, self.mapping, **dataset_params
+        self.dataset = self.dataset_class(
+            self.rank, self.machine_id, self.mapping, **self.dataset_params
         )
 
         logging.info("Dataset instantiation complete.")
 
-        model_class = getattr(dataset_module, dataset_configs["model_class"])
-        self.model = model_class()
+        self.model_class = getattr(dataset_module, dataset_configs["model_class"])
+        self.model = self.model_class()
 
         optimizer_configs = config["OPTIMIZER_PARAMS"]
         optimizer_module = importlib.import_module(
             optimizer_configs["optimizer_package"]
         )
-        optimizer_class = getattr(
+        self.optimizer_class = getattr(
             optimizer_module, optimizer_configs["optimizer_class"]
         )
-        optimizer_params = utils.remove_keys(
+        self.optimizer_params = utils.remove_keys(
             optimizer_configs, ["optimizer_package", "optimizer_class"]
         )
-        self.optimizer = optimizer_class(self.model.parameters(), **optimizer_params)
+        self.optimizer = self.optimizer_class(
+            self.model.parameters(), **self.optimizer_params
+        )
 
         train_configs = config["TRAIN_PARAMS"]
         train_module = importlib.import_module(train_configs["training_package"])
@@ -172,16 +174,24 @@ class Node:
             **sharing_params
         )
 
+        self.iterations = iterations
+        self.test_after = test_after
+        self.log_dir = log_dir
+
+    def run(self):
+        """
+        Start the decentralized learning loop: train, share, and periodically evaluate on the test set
+        """
         self.testset = self.dataset.get_testset()
-        rounds_to_test = test_after
+        rounds_to_test = self.test_after
 
-        for iteration in range(iterations):
+        for iteration in range(self.iterations):
             logging.info("Starting training iteration: %d", iteration)
             self.trainer.train(self.dataset)
 
             self.sharing.step()
-            self.optimizer = optimizer_class(
-                self.model.parameters(), **optimizer_params
+            self.optimizer = self.optimizer_class(
+                self.model.parameters(), **self.optimizer_params
             )  # Reset optimizer state
             self.trainer.reset_optimizer(self.optimizer)
 
@@ -209,14 +219,14 @@ class Node:
                 "train_loss",
                 "Training Loss",
                 "Communication Rounds",
-                os.path.join(log_dir, "{}_train_loss.png".format(self.rank)),
+                os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
             )
 
             rounds_to_test -= 1
 
             if self.dataset.__testing__ and rounds_to_test == 0:
                 logging.info("Evaluating on test set.")
-                rounds_to_test = test_after
+                rounds_to_test = self.test_after
                 ta, tl = self.dataset.test(self.model, self.loss)
                 results_dict["test_acc"][iteration + 1] = ta
                 results_dict["test_loss"][iteration + 1] = tl
@@ -226,20 +236,83 @@ class Node:
                     "test_loss",
                     "Testing Loss",
                     "Communication Rounds",
-                    os.path.join(log_dir, "{}_test_loss.png".format(self.rank)),
+                    os.path.join(self.log_dir, "{}_test_loss.png".format(self.rank)),
                 )
                 self.save_plot(
                     results_dict["test_acc"],
                     "test_acc",
                     "Testing Accuracy",
                     "Communication Rounds",
-                    os.path.join(log_dir, "{}_test_acc.png".format(self.rank)),
+                    os.path.join(self.log_dir, "{}_test_acc.png".format(self.rank)),
                 )
 
             with open(
-                os.path.join(log_dir, "{}_results.json".format(self.rank)), "w"
+                os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
             ) as of:
                 json.dump(results_dict, of)
 
         self.communication.disconnect_neighbors()
         logging.info("All neighbors disconnected. Process complete!")
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        *args
+    ):
+        """
+        Constructor
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        iterations : int, optional
+            Number of learning iterations (communication rounds) to run
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                rounds = 25
+                batch_size = 64
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        args : optional
+            Other arguments
+        """
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            test_after,
+            *args
+        )
+
+        self.run()
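With construction split into instantiate() and run(), a caller that wants to build the node's objects without immediately training can override __init__. A hypothetical sketch (DeferredNode and the commented call site are illustrative assumptions, not part of the repository):

```python
# Hypothetical sketch: build the node's objects (dataset, model, optimizer, trainer,
# sharing) without entering the training loop, then start it explicitly.
from decentralizepy.node.Node import Node


class DeferredNode(Node):
    def __init__(self, *args, **kwargs):
        # Only construct members; Node.__init__ would also call run().
        self.instantiate(*args, **kwargs)


# Example call site (arguments mirror Node.__init__ and are placeholders):
# node = DeferredNode(rank, machine_id, mapping, graph, config,
#                     iterations=10, log_dir="./logs", test_after=5)
# node.run()
```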
diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py
index 6129b34b90e6ae9abbf7e94d7d93dab668cf5b2e..a7fb3b14d9651723a3a1958d90a2177641ec751a 100644
--- a/src/decentralizepy/training/Training.py
+++ b/src/decentralizepy/training/Training.py
@@ -11,7 +11,14 @@ class Training:
     """
 
     def __init__(
-        self, model, optimizer, loss, epochs_per_round="", batch_size="", shuffle=""
+        self,
+        model,
+        optimizer,
+        loss,
+        rounds="",
+        full_epochs="",
+        batch_size="",
+        shuffle="",
     ):
         """
         Constructor
@@ -23,8 +30,10 @@ class Training:
             Optimizer to learn parameters
         loss : function
             Loss function
-        epochs_per_round : int, optional
-            Number of epochs per training call
+        rounds : int, optional
+            Number of steps/epochs per training call
+        full_epochs : bool, optional
+            True if one round is a full epoch over the dataset; False if one round is a single minibatch step
         batch_size : int, optional
             Number of items to learn over, in one batch
         shuffle : bool
@@ -33,7 +42,8 @@ class Training:
         self.model = model
         self.optimizer = optimizer
         self.loss = loss
-        self.epochs_per_round = utils.conditional_value(epochs_per_round, "", int(1))
+        self.rounds = utils.conditional_value(rounds, "", int(1))
+        self.full_epochs = utils.conditional_value(full_epochs, "", False)
         self.batch_size = utils.conditional_value(batch_size, "", int(1))
         self.shuffle = utils.conditional_value(shuffle, "", False)
 
@@ -68,25 +78,61 @@ class Training:
         logging.info("Loss after iteration: {}".format(loss))
         return loss
 
-    def train(self, dataset):
+    def trainstep(self, data, target):
+        """
+        One training step on a minibatch.
+        Parameters
+        ----------
+        data : any
+            Data item
+        target : any
+            Label
+        Returns
+        -------
+        int
+            Loss Value for the step
+        """
+        self.model.zero_grad()
+        output = self.model(data)
+        loss_val = self.loss(output, target)
+        loss_val.backward()
+        self.optimizer.step()
+        return loss_val.item()
+
+    def train_full(self, trainset):
         """
         One training iteration, goes through the entire dataset
         Parameters
         ----------
+        trainset : torch.utils.data.Dataloader
+            The training dataset.
+        """
+        for epoch in range(self.rounds):
+            epoch_loss = 0.0
+            count = 0
+            for data, target in trainset:
+                epoch_loss += self.trainstep(data, target)
+                count += 1
+            logging.info("Epoch: {} loss: {}".format(epoch, epoch_loss / count))
+
+    def train(self, dataset):
+        """
+        One training iteration
+        Parameters
+        ----------
         dataset : decentralizepy.datasets.Dataset
             The training dataset. Should implement get_trainset(batch_size, shuffle)
         """
         trainset = dataset.get_trainset(self.batch_size, self.shuffle)
 
-        for epoch in range(self.epochs_per_round):
-            epoch_loss = 0.0
+        if self.full_epochs:
+            self.train_full(trainset)
+        else:
+            iter_loss = 0.0
             count = 0
             for data, target in trainset:
-                self.model.zero_grad()
-                output = self.model(data)
-                loss_val = self.loss(output, target)
-                epoch_loss += loss_val.item()
-                loss_val.backward()
-                self.optimizer.step()
+                iter_loss += self.trainstep(data, target)
                 count += 1
-            logging.info("Epoch: {} loss: {}".format(epoch, epoch_loss / count))
+                logging.info("Round: {} loss: {}".format(count, iter_loss / count))
+                if count >= self.rounds:
+                    break
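One way to exercise the two training modes end to end: with full_epochs left at its default, train() performs `rounds` minibatch steps; with full_epochs=True it performs `rounds` full passes over the data. A self-contained sketch with a toy dataset (ToyDataset is an illustrative stand-in for decentralizepy.datasets.Dataset and only provides the get_trainset(batch_size, shuffle) call that train() relies on):

```python
# Minimal illustration of the rounds/full_epochs semantics of Training.train().
import torch
from torch.utils.data import DataLoader, TensorDataset

from decentralizepy.training.Training import Training


class ToyDataset:
    """Stand-in for decentralizepy.datasets.Dataset with only get_trainset()."""

    def __init__(self):
        self.x = torch.randn(64, 10)
        self.y = torch.randint(0, 2, (64,))

    def get_trainset(self, batch_size, shuffle):
        return DataLoader(
            TensorDataset(self.x, self.y), batch_size=batch_size, shuffle=shuffle
        )


model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = torch.nn.CrossEntropyLoss()

# 5 minibatch steps per call to train() (full_epochs defaults to False)
step_trainer = Training(model, optimizer, loss, rounds=5, batch_size=8, shuffle=True)
step_trainer.train(ToyDataset())

# 2 full passes over the dataset per call to train()
epoch_trainer = Training(
    model, optimizer, loss, rounds=2, full_epochs=True, batch_size=8, shuffle=True
)
epoch_trainer.train(ToyDataset())
```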