【TensorFlow.js】Trying Pose Estimation with Pose Detection

 In this article, I'll try pose estimation using Pose Detection from TensorFlow.js. The code reuses what I wrote in the previous article.

 The finished code is available here:

Angular-sample/tensorflow at main · tsuneken5/Angular-sample

What Is Pose Estimation?

 Pose estimation is a technique that analyzes the positions of body parts such as the shoulders, elbows, and knees in a video or image to estimate a person's posture.

Environment

  • node: 18.20.0
  • Angular CLI: 17.0.0
  • TensorFlow.js: 4.22.0
  • TensorFlow-model / Pose Detection: 2.1.3
GitHub - tensorflow/tfjs: A WebGL accelerated JavaScript library for training and deploying ML models.
tfjs-models/pose-detection at master · tensorflow/tfjs-models

Installing the Packages

$ npm install @tensorflow/tfjs
$ npm install @tensorflow-models/pose-detection

 Don't forget to add "skipLibCheck": true to tsconfig.json so that type checking of the library's declaration files is skipped.
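 For reference, the relevant part of tsconfig.json would look like this (a minimal sketch; your other compiler options will differ):

```json
{
  "compilerOptions": {
    "skipLibCheck": true
  }
}
```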

Creating the Service

$ ng generate service service/canvas
  public async uploadImage(imgStr: string, maxWidth: number, canvas: HTMLCanvasElement): Promise<void> {
    const img = new Image();
    const ctx = canvas.getContext('2d') as CanvasRenderingContext2D;
    let scaleRatio: number = 1;

    img.src = imgStr;

    return new Promise((resolve, reject) => {
      img.onload = () => {
        let imageWidth = img.naturalWidth;
        if (imageWidth > maxWidth) {
          imageWidth = maxWidth;
          scaleRatio = imageWidth / img.naturalWidth;
        }
        canvas.width = imageWidth;
        const imageHeight = img.naturalHeight * scaleRatio;
        canvas.height = imageHeight;

        ctx.drawImage(img, 0, 0, imageWidth, imageHeight);
        resolve();
      };
      img.onerror = (error) => {
        console.error(error);
        reject(error);
      };
    });
  }

  public markPoint(x: number, y: number, color: string, canvas: HTMLCanvasElement): void {
    const ctx = canvas.getContext('2d') as CanvasRenderingContext2D;

    ctx.strokeStyle = color;
    ctx.fillStyle = color;
    ctx.lineWidth = 1;

    ctx.beginPath();
    ctx.arc(x, y, 3, 0, 2 * Math.PI);
    ctx.closePath();
    ctx.fill();
  }

  public strokeLine(x1: number, y1: number, x2: number, y2: number, color: string, canvas: HTMLCanvasElement): void {
    const ctx = canvas.getContext('2d') as CanvasRenderingContext2D;

    ctx.strokeStyle = color;
    ctx.lineWidth = 1;

    ctx.beginPath();
    ctx.moveTo(x1, y1);
    ctx.lineTo(x2, y2);
    ctx.stroke();
  }

uploadImage

 A method that draws a base64-encoded image onto the canvas, scaling it down to maxWidth if necessary.
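 The resizing logic inside uploadImage can be isolated as a small pure function (a sketch for illustration; computeScaledSize is not part of the actual service):

```typescript
// Compute the canvas size for an image constrained to maxWidth,
// preserving the aspect ratio (mirrors the logic in uploadImage).
function computeScaledSize(
  naturalWidth: number,
  naturalHeight: number,
  maxWidth: number
): { width: number; height: number } {
  const scaleRatio = naturalWidth > maxWidth ? maxWidth / naturalWidth : 1;
  return {
    width: naturalWidth * scaleRatio,
    height: naturalHeight * scaleRatio,
  };
}
```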

markPoint

 A method that draws a dot on the canvas. It is used to render the pose-estimation keypoints.

strokeLine

 A method that draws a line on the canvas. It is used to connect the pose-estimation keypoints.

Creating the Component

$ ng generate component component/page/pose-detection
import { Component } from '@angular/core';

import * as tf from '@tensorflow/tfjs';
import * as poseDetection from '@tensorflow-models/pose-detection';

import { ImageUploaderComponent } from '../../share/image-uploader/image-uploader.component';
import { LoadingSpinnerComponent } from '../../share/loading-spinner/loading-spinner.component';
import { LoadingSpinnerService } from '../../../service/loading-spinner.service';
import { CanvasService } from '../../../service/canvas.service';

@Component({
  selector: 'app-pose-detection',
  standalone: true,
  imports: [ImageUploaderComponent, LoadingSpinnerComponent],
  templateUrl: './pose-detection.component.html',
  styleUrl: './pose-detection.component.css'
})
export class PoseDetectionComponent {
  private canvas!: HTMLCanvasElement;
  private detector!: poseDetection.PoseDetector;
  private readonly KEYPOINT_LINKS = [
    [0, 1], [0, 2], [1, 3], [2, 4], [5, 6], [5, 7], [7, 9], [6, 8], [8, 10], [11, 12], [5, 11], [6, 12], [11, 13], [13, 15], [12, 14], [14, 16]
  ];
  private readonly LINK_COLORS = [
    '#FF0000', '#FF0000', '#FF0000', '#FF0000',
    '#0000FF', '#0000FF', '#0000FF', '#0000FF', '#0000FF',
    '#800080', '#800080', '#800080',
    '#008000', '#008000', '#008000', '#008000'
  ];
  private readonly KEYPOINT_COLORS = [
    '#FF0000', '#FF69B4', '#FF69B4', '#FF4500', '#FF4500',
    '#0000FF', '#0000FF', '#1E90FF', '#1E90FF', '#00BFFF', '#00BFFF',
    '#008000', '#008000', '#32CD32', '#32CD32', '#90EE90', '#90EE90'
  ];

  constructor(
    private loadingSpinnerService: LoadingSpinnerService,
    private canvasService: CanvasService
  ) { }

  private async loadModel(): Promise<void> {
    this.loadingSpinnerService.show();
    await tf.ready();
    const model = poseDetection.SupportedModels.PoseNet;
    this.detector = await poseDetection.createDetector(model);
    this.loadingSpinnerService.hide();
  }

  private async detect(canvas: HTMLCanvasElement): Promise<poseDetection.Pose[]> {
    const predictions = await this.detector.estimatePoses(canvas);
    return predictions;
  }

  private isContains(position: { x: number, y: number }): boolean {
    return (0 < position.x) && (position.x < this.canvas.width) && (0 < position.y) && (position.y < this.canvas.height);
  }

  public async startDetected(image: string): Promise<void> {
    this.loadingSpinnerService.show();
    const parent = this.canvas.parentElement as HTMLElement;
    const width = parent.clientWidth;
    await this.canvasService.uploadImage(image, width, this.canvas);

    const results = await this.detect(this.canvas);
    console.log(results);
    for (const result of results) {
      // ライン
      for (let i = 0; i < this.KEYPOINT_LINKS.length; i++) {
        const link = this.KEYPOINT_LINKS[i];
        const color = this.LINK_COLORS[i];
        if (this.isContains(result.keypoints[link[0]]) && this.isContains(result.keypoints[link[1]])) {
          this.canvasService.strokeLine(result.keypoints[link[0]].x, result.keypoints[link[0]].y, result.keypoints[link[1]].x, result.keypoints[link[1]].y, color, this.canvas);
        }
      }
      // キーポイント
      for (let i = 0; i < result.keypoints.length; i++) {
        const keypoint = result.keypoints[i];
        if (this.isContains(keypoint)) {
          this.canvasService.markPoint(keypoint.x, keypoint.y, this.KEYPOINT_COLORS[i], this.canvas);
        }
      }
    }
    this.loadingSpinnerService.hide();
  }

  async ngOnInit(): Promise<void> {
    await this.loadModel();
  }

  ngAfterViewInit() {
    this.canvas = document.getElementById('myCanvas') as HTMLCanvasElement;
  }
}
  • KEYPOINT_LINKS ・・・ the array of keypoint index pairs to connect with lines.
  • LINK_COLORS ・・・ the color of each line.
  • KEYPOINT_COLORS ・・・ the color of each keypoint.
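 For reference, the model returns 17 keypoints in the COCO order below, so index 0 in KEYPOINT_LINKS refers to the nose, 5 and 6 to the shoulders, and so on (the constant name COCO_KEYPOINT_NAMES is my own, not from the library):

```typescript
// The 17 COCO keypoints, in the index order used by KEYPOINT_LINKS.
const COCO_KEYPOINT_NAMES: string[] = [
  'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
  'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
  'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
  'left_knee', 'right_knee', 'left_ankle', 'right_ankle',
];

// Example: the link [5, 7] connects left_shoulder to left_elbow.
```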

loadModel

 Loads the model. This article uses PoseNet; the pose-detection package also supports MoveNet and BlazePose.

detect

 Runs pose estimation on the image drawn on the canvas. A result like the following is returned:

[
    {
        "keypoints": [
            {
                "y": 265.61744589007776,
                "x": 1044.5694005350194,
                "score": 0.9723432064056396,
                "name": "nose"
            },
            {
                "y": 255.01532101167308,
                "x": 1051.313077577821,
                "score": 0.9521362781524658,
                "name": "left_eye"
            },
            {
                "y": 261.65267813715946,
                "x": 1042.7933487354085,
                "score": 0.9519529938697815,
                "name": "right_eye"
            },
            {
                "y": 254.73952760214002,
                "x": 1071.2689992704281,
                "score": 0.8407076597213745,
                "name": "left_ear"
            },
            {
                "y": 258.79438989542797,
                "x": 1029.435721668288,
                "score": 0.7747775912284851,
                "name": "right_ear"
            },
            {
                "y": 316.575731091926,
                "x": 1082.7326270671206,
                "score": 0.9951978325843811,
                "name": "left_shoulder"
            },
            {
                "y": 315.3484542193579,
                "x": 1009.7679809095331,
                "score": 0.9927864670753479,
                "name": "right_shoulder"
            },
            {
                "y": 398.0670978234435,
                "x": 1122.0836119892997,
                "score": 0.9632701277732849,
                "name": "left_elbow"
            },
            {
                "y": 386.9024045476653,
                "x": 974.6007873297666,
                "score": 0.9437273740768433,
                "name": "right_elbow"
            },
            {
                "y": 476.3234587791828,
                "x": 1129.4559368920234,
                "score": 0.9276600480079651,
                "name": "left_wrist"
            },
            {
                "y": 449.4165400048637,
                "x": 943.3600589737355,
                "score": 0.9603907465934753,
                "name": "right_wrist"
            },
            {
                "y": 472.80790977626447,
                "x": 1074.6567211819067,
                "score": 0.9976524710655212,
                "name": "left_hip"
            },
            {
                "y": 476.3426100437742,
                "x": 1025.6612506079766,
                "score": 0.9973909258842468,
                "name": "right_hip"
            },
            {
                "y": 613.98358463035,
                "x": 1075.654866853113,
                "score": 0.9897573590278625,
                "name": "left_knee"
            },
            {
                "y": 618.6457776021399,
                "x": 1035.5776538180935,
                "score": 0.9861435890197754,
                "name": "right_knee"
            },
            {
                "y": 734.1769972033072,
                "x": 1063.1519789640079,
                "score": 0.9153473973274231,
                "name": "left_ankle"
            },
            {
                "y": 742.0976410505835,
                "x": 1055.129043044747,
                "score": 0.8760051131248474,
                "name": "right_ankle"
            }
        ],
        "score": 0.9433674812316895
    }
]

 I wanted the points to stand out, so I draw the lines first and then the points; if you would rather emphasize the lines, drawing the points first and the lines second may work better.
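 Note that isContains only checks coordinates. As the JSON above shows, each keypoint also carries a confidence score, so filtering out low-confidence keypoints before drawing can reduce stray points. A sketch (the 0.3 threshold and the function name filterByScore are my own assumptions, not values from the library):

```typescript
interface Keypoint {
  x: number;
  y: number;
  score?: number;
  name?: string;
}

// Keep only keypoints whose confidence score meets the threshold.
function filterByScore(keypoints: Keypoint[], threshold = 0.3): Keypoint[] {
  return keypoints.filter((kp) => (kp.score ?? 0) >= threshold);
}
```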

Execution Results

Original image

https://pixabay.com/photos/woman-nature-freedom-fashion-7841725/

Result

 This was one of the better results; overall, the accuracy doesn't seem that great...
