2014年11月26日水曜日

超簡単な機械学習によるゲイン決め

ページの一番最後にプログラムおいてきます。


PIDとかFF(フィードフォワード)とかのゲインがよくわからない

上げたら正解なのか、下げたら正解なのかよくわからん

というわけでロボットに決めてもらうことにした


やり方はすごい簡単で
1.あるパラメータで一回動作をする
2.停止
3.動作の評価をする 良ければ、最適パラメータとして保存する
4.戻す
5.停止
6.パラメータを変える
7.繰り返す

本当は同じパラメータで繰り返して平均を取ったり、
サンプルを取って、グラフにして最適値を出すのが良いんだろうけど、
めんどくさいのでなし

あと、たぶん値の更新の仕方もクイックソートのアルゴリズム(半分にしてくやつ)がいいとは思う

個人的な感想としてはすごい楽(パラメータ決める最中は動作が途中で止まらないか見てるだけ)だし、めんどくさがりな俺よりかは圧倒的に良いパラメータが得られましたとさ。


俺の手順としては、まず

FFの直進するときの速度ゲインを決めます。
このときの評価関数は、指定速度に対しての誤差(絶対値)の蓄積量をみました。
 ※冷静に考えたら、同じフレーム数だと限らないので、実際は蓄積した誤差の平均を見たほうがいいですね
あたりまえだけど、少ないほど良い。
あと、評価関数は最高速度に達しているときだけ動かしてます。

で、加速度ゲインは同じ評価関数だけど、
こっちは最初から最後まで評価関数を動かしてます。

直進に関してのPIDは
まず、Pゲインを決めて、I決めて、D決めました。
同じように、評価関数は指定速度に対しての誤差(絶対値)の蓄積をみました
こっちは最初から最後まで評価関数を動かしてます。
というか、FFの速度ゲインが特殊


次にFFの回転ゲインを調整
速度ゲイン調整して、加速度ゲインを調整
このときの評価関数は、回転誤差(絶対値)の蓄積量をみました。
直進のときと同じく少ないほど良い。

で、回転のPIDゲイン
直進と同じでPゲインを決めて、I決めて、D決めました。
評価関数は、回転誤差(絶対値)の蓄積量。


すごい楽なんで、マジでおすすめ
手で決めるのが馬鹿らしい


最後にプログラム(修正版)おいてきます。
一応、一番下に原型のプログラムおいておきます。
関数とかそれっぽい名前になってるので、察してください
一つだけいうと
MoveStraightWay()は移動が終わったらTRUEを返します。

FFを決めるプログラム

void
CheckParameterFFS()
{
  static int mode = 0;
  static int paraMode = 1;
  static float missRoot = 0;
  static float minMissRoot = 1000000;
  static float minVelo = 0;
  static float minAccel = 0;
  static int waitCounter = 0;
  static int avgCounter = 0;

  switch (mode)
    {
  case 0:
    if (MoveStraightWay(GetStraightLength(20)) == TRUE) //前進
      {
        mode++;

        if (paraMode == 0) //速度のゲイン決め
          {
            if (missRoot/avgCounter < minMissRoot)  //最適値の更新
              {
                minMissRoot = missRoot/avgCounter;
                minVelo = GetStraightVeloEff();
              }

            SetStraightVeloEff(GetStraightVeloEff() + 1.5);//パラメータの更新

            if (GetStraightVeloEff() > 40.0)//次のパラメータに
              {
                minMissRoot = 10000;
                paraMode++;
                SetStraightVeloEff(minVelo);
              }
          }
        else if (paraMode == 1) //加速度のゲイン決め
          {
            if (missRoot/avgCounter < minMissRoot)  //最適値の更新
              {
                minMissRoot = missRoot/avgCounter;
                minAccel = GetStraightaccelEff();
                usart_printf("TA %d  ", (int) (minAccel * 1000));
                usart_printf("\n\r");
              }
            SetStraightaccelEff(GetStraightaccelEff() + 1.5);//パラメータの更新

            if (GetStraightaccelEff() > 30.0)//次のパラメータに
              {
                minMissRoot = 10000;
                paraMode++;
                SetStraightaccelEff(minAccel);
                usart_printf("TA %d  ", (int) (minAccel * 1000));
                usart_printf("\n\r");
              }

          }
        else if (paraMode == 2)//終わり
          {
            usart_printf("FFVeloEff %f  ", (int) (minVelo);
            usart_printf("FFAccelEff %f  ", (int) (minAccel));
            usart_printf("\n\r");
            return;
          }
      }
    else
      {//評価関数 速度のゲイン
        if (paraMode == 0)
          {
            if (getStraigthVejyeCounter() == MAX_VELO)//最高速度時のみ評価
              {
                missRoot += abs(GetSMonoMoveLength() - GetStraightVelo());
                 avgCounter ++;
              }
          }
        if (paraMode == 1)//常に評価 加速度のゲイン
          {
            missRoot += abs(GetSMonoMoveLength() - GetStraightVelo());
            avgCounter ++;
          }
      }
    break;
  case 1: //停止
    waitCounter++;
    StillPosition(0, 0);
    if (waitCounter > 1000)
      {
        waitCounter = 0;
        mode++;
      }

    break;
  case 2: //戻る
    if (MoveStraightWay(-GetStraightLength(20)) == TRUE)
      {
        usart_printf("end back");
        usart_printf("\n\r");
        mode++;
      }
    break;
  case 3: //停止
    waitCounter++;
    StillPosition(0, 0);
    if (waitCounter > 1000)
      {
        waitCounter = 0;
        mode = 0;
        missRoot = 0;
        avgCounter  = 0;
      }

    break;

    }
}


PIDを決めるプログラム  FFとひな形は一緒


void
CheckParameter()
{
  static int mode = 0;
  static int paraMode = 0;
  static float missRoot = 0;
  static float minMissRoot = 1000000;
  static float minTKP = 0;
  static float minTKI = 0;
  static float minTKD = 0;
  static int waitCounter = 0;
  static int avgCounter = 0;

  switch (mode)
    {
  case 0:
    if (MoveStraightWay(GetStraightLength(20)) == TRUE)
      {
        mode++;
        if (paraMode == 0)
          {
            SetTKP(GetTKP() + 500);
            if (missRoot/avgCounter  < minMissRoot)
              {
                minMissRoot = missRoot/avgCounter;
                minTKP = GetTKP();
                usart_printf("TP %d  ", (int) (minTKP));
              }
            if (GetTKP() > 10.0)
              {
                minMissRoot = 10000;
                paraMode++;
                SetTKP(minTKP / 2.0);
                usart_printf("TP %d  ", (int) (minTKP));
                usart_printf("\n\r");
              }
          }
        else if (paraMode == 1)
          {
            SetTKI(GetTKI() + 300);
            if (missRoot/avgCounter< minMissRoot)
              {
                minMissRoot = missRoot/avgCounter;
                minTKI = GetTKI();
                usart_printf("TI %d  ", (int) (minTKI);
                usart_printf("\n\r");
              }
            if (GetTKI() > 5000.0)
              {
                minMissRoot = 10000;
                paraMode++;
                SetTKI(minTKI);
                usart_printf("TI %d  ", (int) (minTKI));
                usart_printf("\n\r");
              }

          }
        else if (paraMode == 2)
          {
            SetTKD(GetTKD() + 300);
            if (missRoot/avgCounter< minMissRoot)
              {
                minMissRoot = missRoot/avgCounter;
                minTKD = GetTKD();
                usart_printf("TD %d  ", (int) (minTKD));
                usart_printf("\n\r");
              }
            if (GetTKD() > 5000.0)
              {
                paraMode++;
                SetTKD(minTKD);
                usart_printf("TD %d  ", (int) (minTKD));
                usart_printf("\n\r");
              }
          }
        else if (paraMode == 3)
          {
            usart_printf("TP %d  ", (int) (minTKP);
            usart_printf("TI %d  ", (int) (minTKI));
            usart_printf("TD %d  ", (int) (minTKD));
            usart_printf("\n\r");
            return;
          }
      }
    else
      {
        missRoot += abs(GetTurnMovedLength());
        avgCounter ++;
      }
    break;
  case 1:
    waitCounter++;
    if (waitCounter > 10000)
      {
        waitCounter = 0;
        StillPosition(0, 0);
        mode++;
      }
    break;
  case 2:
    if (MoveStraightWay(-GetStraightLength(20)) == TRUE)
      mode++;
    break;
  case 3:
    waitCounter++;
    if (waitCounter > 10000)
      {
        waitCounter = 0;
        StillPosition(0, 0);
        mode = 0;
        missRoot = 0;
        avgCounter  = 0;
      }

    break;

    }
}










FFを決めるプログラム

void
CheckParameterFFS()
{
  static int mode = 0;
  static int paraMode = 1;
  static float missRoot = 0;
  static float minMissRoot = 1000000;
  static float minVelo = 0;
  static float minAccel = 0;
  static int waitCounter = 0;

  switch (mode)
    {
  case 0:
    if (MoveStraightWay(GetStraightLength(20)) == TRUE) //前進
      {
        mode++;

        if (paraMode == 0) //速度のゲイン決め
          {
            if (missRoot < minMissRoot)  //最適値の更新
              {
                minMissRoot = missRoot;
                minVelo = GetStraightVeloEff();
              }

            SetStraightVeloEff(GetStraightVeloEff() + 1.5);//パラメータの更新

            if (GetStraightVeloEff() > 40.0)//次のパラメータに
              {
                minMissRoot = 10000;
                paraMode++;
                SetStraightVeloEff(minVelo);
              }
          }
        else if (paraMode == 1) //加速度のゲイン決め
          {
            if (missRoot < minMissRoot)  //最適値の更新
              {
                minMissRoot = missRoot;
                minAccel = GetStraightaccelEff();
                usart_printf("TA %d  ", (int) (minAccel * 1000));
                usart_printf("\n\r");
              }
            SetStraightaccelEff(GetStraightaccelEff() + 1.5);//パラメータの更新

            if (GetStraightaccelEff() > 30.0)//次のパラメータに
              {
                minMissRoot = 10000;
                paraMode++;
                SetStraightaccelEff(minAccel);
                usart_printf("TA %d  ", (int) (minAccel * 1000));
                usart_printf("\n\r");
              }

          }
        else if (paraMode == 2)//終わり
          {
            usart_printf("FFVeloEff %f  ", (int) (minVelo);
            usart_printf("FFAccelEff %f  ", (int) (minAccel));
            usart_printf("\n\r");
            return;
          }
      }
    else
      {//評価関数 速度のゲイン
        if (paraMode == 0)
          {
            if (getStraigthVejyeCounter() == MAX_VELO)//最高速度時のみ評価
              {
                missRoot += abs(GetSMonoMoveLength() - GetStraightVelo());
              }
          }
        if (paraMode == 1)//常に評価 加速度のゲイン
          {
            missRoot += abs(GetSMonoMoveLength() - GetStraightVelo());
          }
      }
    break;
  case 1: //停止
    waitCounter++;
    StillPosition(0, 0);
    if (waitCounter > 1000)
      {
        waitCounter = 0;
        mode++;
      }

    break;
  case 2: //戻る
    if (MoveStraightWay(-GetStraightLength(20)) == TRUE)
      {
        usart_printf("end back");
        usart_printf("\n\r");
        mode++;
      }
    break;
  case 3: //停止
    waitCounter++;
    StillPosition(0, 0);
    if (waitCounter > 1000)
      {
        waitCounter = 0;
        mode = 0;
        missRoot = 0;
      }

    break;

    }
}


PIDを決めるプログラム  FFとひな形は一緒


void
CheckParameter()
{
  static int mode = 0;
  static int paraMode = 0;
  static float missRoot = 0;
  static float minMissRoot = 1000000;
  static float minTKP = 0;
  static float minTKI = 0;
  static float minTKD = 0;
  static int waitCounter = 0;

  switch (mode)
    {
  case 0:
    if (MoveStraightWay(GetStraightLength(20)) == TRUE)
      {
        mode++;
        if (paraMode == 0)
          {
            SetTKP(GetTKP() + 500);
            if (missRoot < minMissRoot)
              {
                minMissRoot = missRoot;
                minTKP = GetTKP();
                usart_printf("TP %d  ", (int) (minTKP));
              }
            if (GetTKP() > 10.0)
              {
                minMissRoot = 10000;
                paraMode++;
                SetTKP(minTKP / 2.0);
                usart_printf("TP %d  ", (int) (minTKP));
                usart_printf("\n\r");
              }
          }
        else if (paraMode == 1)
          {
            SetTKI(GetTKI() + 300);
            if (missRoot < minMissRoot)
              {
                minMissRoot = missRoot;
                minTKI = GetTKI();
                usart_printf("TI %d  ", (int) (minTKI);
                usart_printf("\n\r");
              }
            if (GetTKI() > 5000.0)
              {
                minMissRoot = 10000;
                paraMode++;
                SetTKI(minTKI);
                usart_printf("TI %d  ", (int) (minTKI));
                usart_printf("\n\r");
              }

          }
        else if (paraMode == 2)
          {
            SetTKD(GetTKD() + 300);
            if (missRoot < minMissRoot)
              {
                minMissRoot = missRoot;
                minTKD = GetTKD();
                usart_printf("TD %d  ", (int) (minTKD));
                usart_printf("\n\r");
              }
            if (GetTKD() > 5000.0)
              {
                paraMode++;
                SetTKD(minTKD);
                usart_printf("TD %d  ", (int) (minTKD));
                usart_printf("\n\r");
              }
          }
        else if (paraMode == 3)
          {
            usart_printf("TP %d  ", (int) (minTKP);
            usart_printf("TI %d  ", (int) (minTKI));
            usart_printf("TD %d  ", (int) (minTKD));
            usart_printf("\n\r");
            return;
          }
      }
    else
      {
        missRoot += abs(GetTurnMovedLength());
      }
    break;
  case 1:
    waitCounter++;
    if (waitCounter > 10000)
      {
        waitCounter = 0;
        StillPosition(0, 0);
        mode++;
      }
    break;
  case 2:
    if (MoveStraightWay(-GetStraightLength(20)) == TRUE)
      mode++;
    break;
  case 3:
    waitCounter++;
    if (waitCounter > 10000)
      {
        waitCounter = 0;
        StillPosition(0, 0);
        mode = 0;
        missRoot = 0;
      }

    break;

    }
}










0 件のコメント:

コメントを投稿