今回はPysparkのデータフレームの数値をMatplotlibで可視化する方法を紹介したいと思います。
データの読み込み
環境はGooglecolabratoryを想定しています。
読み込むCSVはGooglecolabratoryのノートブックのノード内でデフォルトで配置されている
カリフォルニアの住宅情報のcsvになります。
#csvファイル読み込み
from pyspark.sql import SparkSession
filename = '/content/sample_data/california_housing_test.csv'
spark = SparkSession.builder \
.master("local") \
.appName("app") \
.getOrCreate()
data = spark.read.csv(filename, header=True, inferSchema=True, sep=',')
data.show()
グラフをプロットする
import matplotlib.pyplot as plt
x_ts = range(len( y_ans_val))
y_ans_val = [val['longitude'] for val in data.select('longitude').collect()]
plt.plot(x_ts, y_ans_val)
plt.ylabel('longitude')
plt.title('test plot')
流れとしてはsparkデータフレーム列をリストに変換し、それをmatplotlibでプロット対象に指定しています。
余談
一応データ数的に少量のものであれば、.toPandas()でsparkのデータフレームをpandasのデータフレームに変換してプロットするのもアリだと思います。
# sparkデータフレームをpandasデータフレームに変換する df = data.toPandas()
参照:https://stackoverflow.com/questions/52938842/how-to-plot-using-pyspark


コメント