21
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

GitHub Code Copilotでデータサイエンティストが遊んでみた

Last updated at Posted at 2021-11-03

Code Copilotとは

https://copilot.github.com/
AIがコメントや関数名から続きを推測して補完するツール

使い方

Technical Preview版のみの提供(2021-11-03時点)
GitHubから申込みが必要

対応IDE

  • VS Code
  • JetBrains系

お題1 Google画像検索で検索クエリから画像を集める

search_image_from_googくらいであと全部いけた


# download image from url
def download_image_from_url(url, file_name):
    import urllib.request
    with open(file_name, "wb") as image_file:
        image_file.write(urllib.request.urlopen(url).read())


# search image from google
def search_image_from_google(search_term, num_images):
    # import google image search api
    from google_images_download.google_images_download import google_images_download
    # pip install google_images_download
    response = google_images_download.googleimagesdownload()
    arguments = {"keywords": search_term, "limit": num_images, "print_urls": True}
    paths = response.download(arguments)
    return paths


if __name__ == '__main__':
    search_image_from_google("cat", 0)
    # download_image_from_url("https://www.google.com/imgres?imgurl=https%3A%2F%2Fwww.petfinder.com%2Fwp-content%2Fuploads%2F2012%2F11%2F91615172-find-a-lump-on-cats-skin-632x475.jpg&imgrefurl=https%3A%2F%2Fwww.petfinder.com%2Fphotos%2Fcats%2F91615172%2F&docid=1-_5Xq-XZJ_5M&tbnid=_5Xq-XZJ_5M_6M%3A&vet=10ahUKEwjNy7Ph39ndAhWGFYgKHdDcD8QMwhGKAMwAw..i&w=632&h=475&bih=938&biw=1920&q=cat&ved=0ahUKEwjNy7Ph39ndAhWGFYgKHdDcD8QMwhGKAMwAw&iact=mrc&uact=8")

お題2 明日の株価を予測する

predict_stock_price
get_stock_data
と打つだけで謎APIを叩いて学習して、推論するところまでできた



def get_stock_data(stock_symbol):
    """
    Get the stock data from the web.

    :return: A list of stock data.
    """
    import requests
    import json
    stock_data = []
    url = 'https://www.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol=' + stock_symbol + '&apikey=自分で取得してください'
    response = requests.get(url)
    if response.status_code == 200:
        stock_data_raw = json.loads(response.text)
        stock_data = stock_data_raw['Time Series (Daily)']
    else:
        print('Error:', response.status_code)

    return stock_data


def predict_stock_price(stock_data):
    """
    Predict the stock price for the next day.

    :param stock_data: A list of stock data.
    :return: The predicted stock price.
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import mean_squared_error
    # train stock data
    stock_data_df = pd.DataFrame.from_dict(stock_data, orient='index')
    stock_data_df.index = pd.to_datetime(stock_data_df.index)
    stock_data_df.columns = ['Open', 'High', 'Low', 'Close', 'Volume']
    stock_data_df['Date'] = stock_data_df.index
    stock_data_df['Date'] = stock_data_df['Date'].dt.strftime('%Y-%m-%d')
    stock_data_df['Date'] = pd.to_datetime(stock_data_df['Date'])
    stock_data_df = stock_data_df.sort_values(by='Date')
    stock_data_df = stock_data_df.reset_index(drop=True)
    stock_data_df = stock_data_df.drop(['Date'], axis=1)
    stock_data_df = stock_data_df.drop(['Volume'], axis=1)

    # train linear regression model
    X = stock_data_df.drop(['Close'], axis=1)
    y = stock_data_df['Close']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
    regressor = LinearRegression()
    regressor.fit(X_train, y_train)
    y_pred = regressor.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    print('MSE:', mse)
    # plot the data
    sns.set(style="darkgrid")
    plt.figure(figsize=(12, 8))
    plt.plot(y_test, color='blue', label='Actual')
    plt.plot(y_pred, color='red', label='Predicted')
    plt.title('Stock Price Prediction')
    plt.xlabel('Days')
    plt.ylabel('Stock Price')
    plt.legend()
    plt.show()
    return y_pred[-1]


if __name__ == '__main__':
    # get tesla data
    stock_symbol = 'TSLA'
    stock_data = get_stock_data(stock_symbol)
    predict_stock_price(stock_data)

お題3 犬と猫の分類をする

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
train_cnn_model

すればあとはいい感じにやってくれた

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn


def train_cnn_model(data_dir):
    """
    train CNN model to classify cats or dogs
    :param data_dir:
    :return:
    """
    model = models.resnet18(pretrained=True)
    model.fc = torch.nn.Linear(512, 2)
    # model.cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    transform = transforms.Compose([transforms.Resize(224),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std=[0.229, 0.224, 0.225])])
    train_dataset = torchvision.datasets.ImageFolder(root=data_dir, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    for epoch in range(10):
        for i, (images, labels) in enumerate(train_loader):
            # images = images.cuda()
            # labels = labels.cuda()
            outputs = model(images)
            loss = criterion(outputs.float(), labels.to(torch.float32))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, 10, i + 1, len(train_dataset) // 32, loss.item()))

    torch.save(model.state_dict(), './model/resnet18.pth')
    return model


if __name__ == '__main__':
    train_cnn_model('downloads')
    print('Done')

まとめ

一部意図したものとは違う、動作しないものもあったが、基本タブ連打で実装できた。
いつも書く定番のコードとか書くときにすごく便利。
さらに、関数名を決めるときも提案されるようなものを選べば可読性もあがりそう。

注意点

たまに何も提案してくれないことがあるが、推論に時間がかかっているだけのこともある。ちょっと落ち着いて待ってみよう。

21
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?