diff --git a/lib/nBodySim.cpp b/lib/nBodySim.cpp
index 570011f0c1b8316ee5c8d1a3f055450d1c8ed8f4..a92c3f11f98fe15f33b74c247fffcb9a136ef286 100644
--- a/lib/nBodySim.cpp
+++ b/lib/nBodySim.cpp
@@ -1,4 +1,5 @@
 #include "nBodySim.hpp"
+#include "tree.hpp"
 #include <omp.h>
 #include <cmath>
 #include <string>
@@ -15,10 +16,12 @@ nBodySim::nBodySim(std::string datafile) {
         file >> masses_[i];
     }
 
+    double maxCoord = 0;
     positions_ = new double[3*nParticles_];
     for (unsigned i=0; i < 3; ++i) {
         for (unsigned j=0; j < nParticles_; ++j) {
             file >> positions_[i+3*j];
+            if (std::abs(positions_[i+3*j]) > maxCoord) maxCoord = std::abs(positions_[i+3*j]);
         }
     }
     
@@ -45,6 +48,8 @@ nBodySim::nBodySim(std::string datafile) {
     }
     
     file.close();
+
+    tree_ = new Tree(positions_, forces_, nParticles_, 4*maxCoord, new double[3]{0, 0, 0});
 }
 
 nBodySim::~nBodySim() {
@@ -107,38 +112,9 @@ void nBodySim::calculateForces() {
     }
 }
 
-/*
-// naive implementation
-void nBodySim::calculateForces() {
-    #pragma omp parallel for
-    for (unsigned i=0; i < 3*nParticles_; ++i) {
-        forces_[i] = 0.0;
-    }
-    // loop over each pair of particles and calculate the force between them
-    #pragma omp parallel for collapse(2)
-    for (unsigned i = 0; i < nParticles_; i++) {
-        for (unsigned j = 0; j < nParticles_; j++) {
-            if (i == j) continue;
-            // calculate distance between particles
-            double dx = positions_[3*j]   - positions_[3*i];
-            double dy = positions_[3*j+1] - positions_[3*i+1];
-            double dz = positions_[3*j+2] - positions_[3*i+2];
-            double r2 = dx*dx + dy*dy + dz*dz;
-            double r = std::sqrt(r2);
-            dx /= r;
-            dy /= r;
-            dz /= r;
-            // calculate force
-            double mi = masses_[i];
-            double mj = masses_[j];
-            double s = softening_[i];
-            forces_[3*i]   += mi*mj * dx / (r2 + s*s);
-            forces_[3*i+1] += mi*mj * dy / (r2 + s*s);
-            forces_[3*i+2] += mi*mj * dz / (r2 + s*s);
-        }
-    }
+void nBodySim::treeCalculateForces() {
+    // TODO
 }
-*/
 
 double nBodySim::calculateMeanInterparticleDistance() {
     meanInterparticleDistance_ = 0.0;
@@ -184,8 +160,10 @@ void nBodySim::doTimeStep(double dt) {
     for (unsigned i=0; i < nParticles_; ++i) {
         for (unsigned j=0; j < 3; ++j) {
             // TODO
-            // update velocity
-            // update position
+            tree_->drift(dt/2);
+            tree_->update();
+            tree_->kick(dt);
+            tree_->drift(dt/2);
         }
     }
 }
diff --git a/lib/nBodySim.hpp b/lib/nBodySim.hpp
index 997cf3b3521a52cf3603a2f684c48134d89cf806..cd9fbc897d84c69bdf153aef5e7935ae7859bad2 100644
--- a/lib/nBodySim.hpp
+++ b/lib/nBodySim.hpp
@@ -1,6 +1,7 @@
 #ifndef NBODYSIM_HPP
 #define NBODYSIM_HPP
 
+#include "tree.hpp"
 #include <string>
 
 class nBodySim {
@@ -13,6 +14,8 @@ public:
     void runSimulation(double dt, unsigned nSteps);
     // loops over all pairs of particles and calculates forces, calculates current mean interparticle distance 
     void calculateForces();
+    // use treecode to calculate forces
+    void treeCalculateForces();
     // loop over all pairs of particles and calculate mean interparticle distance
     double calculateMeanInterparticleDistance();
     // loop over all force vectors and calculate mean force magnitude
@@ -32,6 +35,7 @@ public:
     double getMeanInterparticleDistance() const { return meanInterparticleDistance_; }
     double getMeanForceMagnitude() const { return meanForceMagnitude_; }
 private:
+    Tree* tree_;
     unsigned nParticles_;
     // 3d vectors are stored as a array of length 3*nParticles_ in the format x1, y1, z1, x2, y2, z2, ...
     double* masses_;
diff --git a/lib/tree.cpp b/lib/tree.cpp
index fe348c93ef470ad4a108e8997d7f0fd260396c6e..75b9257574b9efad674e75cfda6483e83d2e7f05 100644
--- a/lib/tree.cpp
+++ b/lib/tree.cpp
@@ -1,8 +1,9 @@
 #include "tree.hpp"
 #include "node.hpp"
 
-Tree::Tree(double* particles, unsigned nParticles, double size, double* center) {
+Tree::Tree(double* particles, double* forces, unsigned nParticles, double size, double* center) {
     particles_ = particles;
+    forces_ = forces;
     nParticles_ = nParticles;
     size_ = size;
 
@@ -21,4 +22,12 @@ Tree::~Tree() {
 
 void Tree::update() {
     // TODO
+}
+
+void Tree::drift(double dt) {
+    // TODO
+}
+
+void Tree::kick(double dt) {
+    // TODO
 }
\ No newline at end of file
diff --git a/lib/tree.hpp b/lib/tree.hpp
index 1da26c6905364150a3861a413799ab682023c005..9578237d61370acd0d2d12c1d2fae069bba66430 100644
--- a/lib/tree.hpp
+++ b/lib/tree.hpp
@@ -7,14 +7,19 @@ class Tree {
 public:
     // constructor
     Tree() = delete;
-    Tree(double* particles, unsigned nParticles, double size, double* center);
+    Tree(double* particles, double* forces_, unsigned nParticles, double size, double* center);
     // destructor
     ~Tree();
     // update tree: visit each node, check if it needs to be split or merged (for example if a particle has left the region)
     void update();
+    // drift tree: visit each node, update center of mass and center of mass velocity
+    void drift(double dt);
+    // kick tree: visit each node, update center of mass velocity
+    void kick(double dt);
 private:
     Node* root_;
     double* particles_;
+    double* forces_;
     unsigned nParticles_;
     double size_;
     double center_[3];