commit 788747ec066234c16643595253ebe4d6bfeebe74
parent 27a0e0949c6ca3f7bd18569a23ddd0e1b3e9a64e
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date: Fri, 10 Jul 2015 17:16:26 -0400
Adjust cluster_arrival.py to make it work again
Diffstat:
1 file changed, 17 insertions(+), 6 deletions(-)
diff --git a/data_analysis/cluster_arrival.py b/data_analysis/cluster_arrival.py
@@ -1,20 +1,31 @@
-import matplotlib.pyplot as plt
+#!/usr/bin/env python
import numpy
import cPickle
import scipy.misc
+import os
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets.samples_generator import make_blobs
from itertools import cycle
-print "Reading arrival point list"
-with open("arrivals.pkl") as f:
- pts = cPickle.load(f)
+import data
+from data.hdf5 import taxi_it
+from data.transformers import add_destination
+
+print "Generating arrival point list"
+dests = []
+for v in taxi_it("train"):
+ if len(v['latitude']) == 0: continue
+ dests.append([v['latitude'][-1], v['longitude'][-1]])
+pts = numpy.array(dests)
+
+with open(os.path.join(data.path, "arrivals.pkl"), "w") as f:
+ cPickle.dump(pts, f, protocol=cPickle.HIGHEST_PROTOCOL)
print "Doing clustering"
bw = estimate_bandwidth(pts, quantile=.1, n_samples=1000)
print bw
-bw = 0.001
+bw = 0.001 # (
ms = MeanShift(bandwidth=bw, bin_seeding=True, min_bin_freq=5)
ms.fit(pts)
@@ -22,6 +33,6 @@ cluster_centers = ms.cluster_centers_
print "Clusters shape: ", cluster_centers.shape
-with open("arrival-cluters.pkl", "w") as f:
+with open(os.path.join(data.path, "arrival-clusters.pkl"), "w") as f:
cPickle.dump(cluster_centers, f, protocol=cPickle.HIGHEST_PROTOCOL)