diff --git a/cpp_dnn/dnn.cpp b/cpp_dnn/dnn.cpp
index 8421725bc6ab17e3e3d60db4040ea56447c0f93c..bd396fb3bf788423a566a574ac1bf925f506d374 100644
--- a/cpp_dnn/dnn.cpp
+++ b/cpp_dnn/dnn.cpp
@@ -58,8 +58,14 @@ namespace cpp_dnn {
 
         
 
-            // void sgd_step(const scalar lrate){
-            // }
+            virtual void sgd_step(const scalar lrate){
+            }
+
+
+            MatrixXI predict (const MatrixX& Ypred) {
+                return arma::index_max(Ypred,0) ;
+            }
+
             virtual ~Node() {
                 // std::cout<<"deleted_"<<name<<" \n";
             }
@@ -85,9 +91,9 @@ namespace cpp_dnn {
                 //  (in size, out size)
                 _W0.zeros(_n, 1);
                 
-                _W<< 1.24737338 << 0.28295388 << 0.69207227<<arma::endr
-                  << 1.58455078 << 1.32056292 <<-0.69103982<<arma::endr;
-                // _W.randn(m,n);//change with tf.randomNormal(shape1,0,1.0 * m ** (-.5),'float32',0);
+                // _W<< 1.24737338 << 0.28295388 << 0.69207227<<arma::endr
+                //   << 1.58455078 << 1.32056292 <<-0.69103982<<arma::endr;
+                _W.randn(m,n);//change with tf.randomNormal(shape1,0,1.0 * m ** (-.5),'float32',0);
                 // _name="Linear("+m+","+n+")";
 
                 
@@ -164,8 +170,8 @@ namespace cpp_dnn {
                 return this->_W.size()+this->_W0.size();
             }
 
-            // void sgd_step(const scalar lrate){
-            // }
+            void sgd_step(const scalar lrate){
+            }
         // private:
     };
 
@@ -197,8 +203,8 @@ namespace cpp_dnn {
                 return this->_W.size()+this->_W0.size();
             }
 
-            // void sgd_step(const scalar lrate){
-            // }
+            void sgd_step(const scalar lrate){
+            }
         // private:
     };
 
@@ -218,8 +224,8 @@ namespace cpp_dnn {
                 return dLdA % (1. - dLdA); //check wrong probably
             }
 
-            // void sgd_step(const scalar lrate){
-            // }
+            void sgd_step(const scalar lrate){
+            }
         // private:
     };
 
@@ -244,14 +250,15 @@ namespace cpp_dnn {
             }
             MatrixXI predict (const MatrixX& Ypred) {
                 return arma::index_max(Ypred,0) ;
-            };
+            }
 
             unsigned getWeightsNum(){
                 return this->_W.size()+this->_W0.size();
             }
+            
 
-            // void sgd_step(const scalar lrate){
-            // }
+            void sgd_step(const scalar lrate){
+            }
         // private:
     };
 
@@ -277,8 +284,9 @@ namespace cpp_dnn {
             virtual MatrixX backward(){ 
                 return _Ypred-_Y;
             }
-            // virtual void sgd_step(const scalar lrate){
-            // }
+
+            virtual void sgd_step(const scalar lrate){
+            }
             
         // private:
     };
@@ -332,15 +340,18 @@ namespace cpp_dnn {
 
             MatrixX backward(const MatrixX& delta){ 
                 MatrixX res=MatrixX(delta);
-                for(unsigned i=_modules.size()-1;i>=0.;i-- ){
+                for(int i=_modules.size()-1;i>=0;i=i-1 ){
                     res = _modules[i]->backward(res);
                 }
             }
-            // void sgd_step(const scalar lrate){
-            //     for(unsigned i=0;i< _modules.size();i++ ){
-            //         _modules[i]->sgd_step(lrate);
-            //     }
-            // }
+
+            void sgd_step(const scalar lrate){
+                for(unsigned i=0;i< _modules.size();i++ ){
+                    _modules[i]->sgd_step(lrate);
+                }
+                
+            }
+
             MatrixXI step (const MatrixX& Xt,const MatrixX&Yt) { //add argument for weights
                 MatrixX Ypred = this->forward(Xt);
                 // std::cout << Ypred<< '\n';
@@ -437,16 +448,33 @@ namespace cpp_dnn {
             
             
             // void sgd(){
-            void sgd(const MatrixX& X,const MatrixX& Y, const unsigned iters,const scalar lrate){
-                unsigned D= X.n_cols;
-                unsigned N= X.n_rows;
+            void sgd(const MatrixX& X,const MatrixX& Y, const int iters,const scalar lrate){
+                // unsigned D= X.n_cols;
+                // unsigned N= X.n_rows;
+
+                unsigned D= X.n_rows;
+                unsigned N= X.n_cols;
+                unsigned O= Y.n_rows;
                 // var D=X.shape[0];
                 // var N=X.shape[1];
                 int sum_loss=0;
                 // var sum_loss=tf.tensor([0]);
-                for(unsigned i=0;i<iters;i++){
+                
+                for(int i=0;i<iters;i++){
                     unsigned j=(unsigned)(rand() % N);
+                    MatrixX Xt= X.submat( 0, j, D-1, j );
+                    MatrixX Yt= Y.submat( 0, j, O-1, j );
+                    auto Ypred = this->forward(Xt);
+                    sum_loss+= arma::accu( _loss->forward(Ypred, Yt) );
+
+                    auto err = _loss->backward();
+
+                    this->backward(err);
+
+                    this->sgd_step(lrate); 
+                    
                 }
+                std::cout<<"Loss: "<< sum_loss<<std::endl;
                 
                 // for(var i=0; i < iters; i++){
                 //     var j= getRandomInt(N);
diff --git a/cpp_dnn/main.cpp b/cpp_dnn/main.cpp
index 57f6bd88151cd5e74e2981a638d9719e2b47209c..badd39b62d5a573d2037f7af31e1b8dc94856608 100644
--- a/cpp_dnn/main.cpp
+++ b/cpp_dnn/main.cpp
@@ -53,10 +53,7 @@ double fitfun(double const *x, int N,MatrixX X,MatrixX Y,Sequential seq) { /* fu
 //--------------------------------------------------------------------------------------------------
 int main(int const, char const**) {
 
-    srand(0);
-    cout<<getRandom()<<endl;
-
-    // //Create network
+  // //Create network
     // std::vector<Node*> network;
 
     // Linear l1(2, 12);
@@ -88,112 +85,112 @@ int main(int const, char const**) {
 
     // Sequential seq(network, &l);
 
-    Linear linear_1(2,3);
-    scalar lrate = 0.005;
-    // arma::Mat<double> a=arma::randu(2,2);
+    // Linear linear_1(2,3);
+    // scalar lrate = 0.005;
+    // // arma::Mat<double> a=arma::randu(2,2);
 
     
-    arma::Mat<scalar> X;
-    arma::Mat<scalar> Y;
-    X<< 2<<  3<< 9<< 12<<arma::endr
-     << 5<<  2<< 6<<  5<<arma::endr;
-    Y<<0<< 1<< 0<< 1<<arma::endr
-     <<1<< 0<< 1<< 0<<arma::endr;
-    // arma::Mat<scalar> dL_dz1;
-    // dL_dz1<< 1.69467553e-09<<-1.33530535e-06<<  0.00000000e+00<<-0.00000000e+00<<arma::endr
-    //       <<-5.24547376e-07<< 5.82459519e-04<< -3.84805202e-10<< 1.47943038e-09<<arma::endr
-    //       <<-3.47063705e-02<< 2.55611604e-01<< -1.83538094e-02<< 1.11838432e-04<<arma::endr;
-
-    // std::cout << linear_1.forward(X) << '\n';
-    // std::cout << linear_1.backward(dL_dz1) << '\n';
-    // linear_1.sgd_step(lrate);
-    // std::cout << linear_1._W << '\n';
-    // std::cout << linear_1._W0 << '\n';
-
-
-
-    std::cout << arma::index_max(Y,0) << '\n';
-    std::vector<Node*> network;
-    Linear l1(2, 3);
-    ReLU t;
-    Linear l2(3, 2);
-    SoftMax s;
-    NLL l;
-    network.push_back( &l1);
-    network.push_back( &t);
-    network.push_back( &l2);
-    network.push_back( &s);
+    // arma::Mat<scalar> X;
+    // arma::Mat<scalar> Y;
+    // X<< 2<<  3<< 9<< 12<<arma::endr
+    //  << 5<<  2<< 6<<  5<<arma::endr;
+    // Y<<0<< 1<< 0<< 1<<arma::endr
+    //  <<1<< 0<< 1<< 0<<arma::endr;
+    // // arma::Mat<scalar> dL_dz1;
+    // // dL_dz1<< 1.69467553e-09<<-1.33530535e-06<<  0.00000000e+00<<-0.00000000e+00<<arma::endr
+    // //       <<-5.24547376e-07<< 5.82459519e-04<< -3.84805202e-10<< 1.47943038e-09<<arma::endr
+    // //       <<-3.47063705e-02<< 2.55611604e-01<< -1.83538094e-02<< 1.11838432e-04<<arma::endr;
+
+    // // std::cout << linear_1.forward(X) << '\n';
+    // // std::cout << linear_1.backward(dL_dz1) << '\n';
+    // // linear_1.sgd_step(lrate);
+    // // std::cout << linear_1._W << '\n';
+    // // std::cout << linear_1._W0 << '\n';
+
+
+
+    // std::cout << arma::index_max(Y,0) << '\n';
+    // std::vector<Node*> network;
+    // Linear l1(2, 3);
+    // ReLU t;
+    // Linear l2(3, 2);
+    // SoftMax s;
+    // NLL l;
+    // network.push_back( &l1);
+    // network.push_back( &t);
+    // network.push_back( &l2);
+    // network.push_back( &s);
 
-    arma::Mat<scalar> z_1 = l1.forward(X);
+    // arma::Mat<scalar> z_1 = l1.forward(X);
 
-    std::cout << z_1 << '\n';
+    // std::cout << z_1 << '\n';
 
-    arma::Mat<scalar> a_1 = t.forward(z_1);
+    // arma::Mat<scalar> a_1 = t.forward(z_1);
 
-    std::cout << a_1 << '\n';
+    // std::cout << a_1 << '\n';
 
-    arma::Mat<scalar> z_2;
+    // arma::Mat<scalar> z_2;
 
-    // z_2 <<0.40837833<<0.53900088<< 0.56956001<< 0.57209377<<arma::endr
-    //      << -0.66368766<< 0.65353931<< 0.96361427<< 0.98919526<<arma::endr;
+    // // z_2 <<0.40837833<<0.53900088<< 0.56956001<< 0.57209377<<arma::endr
+    // //      << -0.66368766<< 0.65353931<< 0.96361427<< 0.98919526<<arma::endr;
 
 
-    z_2<<5.28714248<< 3.64078533<< 10.92235599<< 12.36410102<<arma::endr
-       <<  0.78906625<< 0.80620366<< 2.41861097<<4.44170662<<arma::endr;
+    // z_2<<5.28714248<< 3.64078533<< 10.92235599<< 12.36410102<<arma::endr
+    //    <<  0.78906625<< 0.80620366<< 2.41861097<<4.44170662<<arma::endr;
     
-    // arma::Mat<scalar> z_2 = l2.forward(a_1);
+    // // arma::Mat<scalar> z_2 = l2.forward(a_1);
 
-    // std::cout << z_2 << '\n';
+    // // std::cout << z_2 << '\n';
 
     
-    arma::Mat<scalar> a_2 = s.forward(z_2);
+    // arma::Mat<scalar> a_2 = s.forward(z_2);
 
-    std::cout << a_2 << '\n';
+    // std::cout << a_2 << '\n';
     
     
 
-    arma::Mat<scalar> loss = l.forward(a_2, Y);
+    // arma::Mat<scalar> loss = l.forward(a_2, Y);
 
-    std::cout << loss<< '\n';
+    // std::cout << loss<< '\n';
 
-    arma::Mat<scalar> dloss = l.backward();
+    // arma::Mat<scalar> dloss = l.backward();
 
-    std::cout << dloss<< '\n';
+    // std::cout << dloss<< '\n';
 
-    arma::Mat<scalar>  dL_dz2 = s.backward(dloss);
+    // arma::Mat<scalar>  dL_dz2 = s.backward(dloss);
 
-    std::cout << dL_dz2<< '\n';
+    // std::cout << dL_dz2<< '\n';
   
-    arma::Mat<scalar> dL_da1;
+    // arma::Mat<scalar> dL_da1;
 
-    // dL_da1<<0.47375374<< -0.3361494<< 0.25611147<<-0.38332583<<arma::endr
-    // <<-0.2210031<< 0.15681155<< -0.11947437<< 0.17881905<<arma::endr
-    // <<-0.56355604<< 0.39986813<< -0.30465863<< 0.45598708<<arma::endr;
+    // // dL_da1<<0.47375374<< -0.3361494<< 0.25611147<<-0.38332583<<arma::endr
+    // // <<-0.2210031<< 0.15681155<< -0.11947437<< 0.17881905<<arma::endr
+    // // <<-0.56355604<< 0.39986813<< -0.30465863<< 0.45598708<<arma::endr;
 
-    dL_da1<< 6.28919807e-01<< -3.52832568e-02<<  6.35791049e-01<< -2.30458563e-04<<arma::endr
-    <<-2.93387075e-01<<  1.64594141e-02<< -2.96592466e-01<<  1.07507449e-04<<arma::endr
-    <<-7.48134578e-01<<  4.19713676e-02<< -7.56308297e-01<<  2.74143091e-04<<arma::endr;
+    // dL_da1<< 6.28919807e-01<< -3.52832568e-02<<  6.35791049e-01<< -2.30458563e-04<<arma::endr
+    // <<-2.93387075e-01<<  1.64594141e-02<< -2.96592466e-01<<  1.07507449e-04<<arma::endr
+    // <<-7.48134578e-01<<  4.19713676e-02<< -7.56308297e-01<<  2.74143091e-04<<arma::endr;
 
-    std::cout << dL_da1<< '\n';
+    // std::cout << dL_da1<< '\n';
 
-    // dL_da1 = linear_2.backward(dL_dz2)
+    // // dL_da1 = linear_2.backward(dL_dz2)
     
 
-    arma::Mat<scalar> dL_dz1 = t.backward(dL_da1);
+    // arma::Mat<scalar> dL_dz1 = t.backward(dL_da1);
 
-    std::cout << dL_dz1<< '\n';
+    // std::cout << dL_dz1<< '\n';
 
-    arma::Mat<scalar> dL_dX = l1.backward(dL_dz1);
+    // arma::Mat<scalar> dL_dX = l1.backward(dL_dz1);
 
-    std::cout << dL_dX<< '\n';
+    // std::cout << dL_dX<< '\n';
 
-    l1.sgd_step(lrate);
-    std::cout << l1._W << '\n';
-    std::cout << l1._W0 << '\n';
+    // l1.sgd_step(lrate);
+    // std::cout << l1._W << '\n';
+    // std::cout << l1._W0 << '\n';
 
-    l1.sgd_step(lrate);
-    std::cout << l1._W << '\n';
-    std::cout << l1._W0 << '\n';
+    // l1.sgd_step(lrate);
+    // std::cout << l1._W << '\n';
+    // std::cout << l1._W0 << '\n';
 
     // dL_dX = linear_1.backward(dL_dz1)
     // unit_test('dL_dX', test_values['dL_dX'], dL_dX)
@@ -222,8 +219,49 @@ int main(int const, char const**) {
     // Sequential seq(network, &l);
     // std::cout <<seq.getWeightsNum()<<'\n';
     // std::cout <<seq.step(X,Y)<< '\n';
+    arma::Mat<scalar> X;
+    arma::Mat<scalar> Y;
+
+    X <<-0.23390341<< 1.18151883<< -2.46493986<< 1.55322202<< 1.27621763<< 2.39710997<< -1.3440304<< -0.46903436<< -0.64673502<< -1.44029872<<
+        -1.37537243<< 1.05994811<< -0.93311512<< 1.02735575<< -0.84138778<<-2.22585412<< -0.42591102<< 1.03561105<< 0.91125595<< -2.26550369<<arma::endr <<
+        -0.92254932<< -1.1030963<< -2.41956036<< -1.15509002<< -1.04805327<< 0.08717325<< 0.8184725<< -0.75171045<< 0.60664705<< 0.80410947<<
+        -0.11600488<< 1.03747218<< -0.67210575<< 0.99944446<< -0.65559838<<-0.40744784<< -0.58367642<< 1.0597278<< -0.95991874<< -1.41720255<<arma::endr;
 
+    Y <<1.<< 1.<< 0.<< 1.<< 1.<< 1.<< 0.<< 0.<< 0.<< 0.<< 0.<< 1.<< 1.<< 1.<< 0.<< 0.<< 0.<< 1.<< 1.<< 0.<<arma::endr
+      <<0.<< 0.<< 1.<< 0.<< 0.<< 0.<< 1.<< 1.<< 1.<< 1.<< 1.<< 0.<< 0.<< 0.<< 1.<< 1.<< 1.<< 0.<< 0.<< 1.<<arma::endr;
 
+    std::vector<Node*> network;
+    Linear l1(2, 10);
+    ReLU r1;
+    Linear l2(10, 10);
+    ReLU r2;
+    Linear l3(10, 2);
+    SoftMax s;
+    NLL l;
+    network.push_back( &l1);
+    network.push_back( &r1);
+    network.push_back( &l2);
+    network.push_back( &r2);
+    network.push_back( &l3);
+    network.push_back( &s);
+
+    Sequential nn(network, &l);
+
+    // Sequential([Linear(2, 10), ReLU(), Linear(10, 10), ReLU(), Linear(10,2), SoftMax()], NLL())
+
+    
+    unsigned it=10000;
+    unsigned D= X.n_rows;
+    unsigned N= X.n_cols;
+    unsigned O= Y.n_rows;
+
+    scalar lrate = 0.005;
+    // Modifies the weights and biases
+    nn.sgd(X, Y, it, lrate);
+
+    // Draw it...
+    
+    cout<<nn.predict(nn.forward(X))<<endl;
 
 
     cout<<"hereeee "<<endl;