Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python 3.x - How can I subsample an array according to its density? (Remove frequent values, keep rare ones)

I want to plot a data distribution where some values occur frequently while others are quite rare. The total number of points is around 30,000. Rendering such a plot as png or (god forbid) pdf takes forever, and the pdf is much too large to display.

So I want to subsample the data just for the plots. What I would like to achieve is to remove a lot of points where they overlap (where the density is high), but keep the ones where the density is low with probability close to 1.

Now, numpy.random.choice allows one to specify a vector of probabilities, which I've computed from the data histogram with a few tweaks. But I can't seem to tune the probabilities so that the rare points are reliably kept.

I've attached an image of the data; the right tail of the distribution has orders of magnitude fewer points, so I'd like to keep those. The data is 3D, but the density comes from only one dimension, so I can use that as a measure of how many points are in a given location.

[image: plot of the data distribution]


1 Reply


Consider the following function. It bins the data into equal-width bins along the x axis and

  • if there are one or two points in a bin, keeps those points,
  • if there are more points in a bin, keeps only the points with the minimum and maximum y value,
  • adds the first and last point to make sure the same data range is used.

This keeps the original data in regions of low density, but significantly reduces the amount of data to plot in regions of high density. At the same time, all the features are preserved as long as the binning is sufficiently dense.

import numpy as np
np.random.seed(42)

def filt(x, y, bins):
    """Keep at most two points per bin; assumes x is sorted in ascending order."""
    d = np.digitize(x, bins)  # bin index of every x value
    xfilt = []
    yfilt = []
    for i in np.unique(d):
        xi = x[d == i]
        yi = y[d == i]
        if len(xi) <= 2:
            # sparse bin: keep every point
            xfilt.extend(list(xi))
            yfilt.extend(list(yi))
        else:
            # dense bin: keep only the points with the largest and smallest y
            xfilt.extend([xi[np.argmax(yi)], xi[np.argmin(yi)]])
            yfilt.extend([yi.max(), yi.min()])
    # prepend/append first/last point if necessary
    if x[0] != xfilt[0]:
        xfilt = [x[0]] + xfilt
        yfilt = [y[0]] + yfilt
    if x[-1] != xfilt[-1]:
        xfilt.append(x[-1])
        yfilt.append(y[-1])
    # restore ascending x order (the max/min pair of a bin may be swapped)
    sort = np.argsort(xfilt)
    return np.array(xfilt)[sort], np.array(yfilt)[sort]

To illustrate the concept, let's use some toy data:

x = np.array([1,2,3,4, 6,7,8,9, 11,14, 17, 26,28,29])
y = np.array([4,2,5,3, 7,3,5,5, 2, 4,  5,  2,5,3])
bins = np.linspace(0,30,7)

Then calling xf, yf = filt(x,y,bins) and plotting both the original data and the filtered data gives:

[image: original and filtered toy data]
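To see why some of the toy points are dropped, the binning step can be inspected on its own. The following is a minimal sketch that repeats the toy x values and bin edges from above and only prints how many points np.digitize assigns to each bin; the variable names labels, counts, left and right are introduced here for illustration.

```python
import numpy as np

# Toy x values and bin edges from the example above
x = np.array([1, 2, 3, 4, 6, 7, 8, 9, 11, 14, 17, 26, 28, 29])
bins = np.linspace(0, 30, 7)  # edges 0, 5, 10, 15, 20, 25, 30

# np.digitize returns i for values with bins[i-1] <= x < bins[i]
d = np.digitize(x, bins)

# Count how many points share each bin index
labels, counts = np.unique(d, return_counts=True)
for label, count in zip(labels, counts):
    left, right = bins[label - 1], bins[label]
    print(f"bin {label} [{left:g}, {right:g}): {count} point(s)")
```

Bins 1, 2 and 6 hold more than two points, so only those are reduced to their minimum and maximum; the points in bins 3 and 4 are kept unchanged, and bin 5 is empty.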

The use case from the question, with some 30000 data points, is shown below. The presented technique reduces the number of plotted points from 30000 to around 500. This number will of course depend on the binning in use; here there are 300 bins. In this case the function takes ~10 ms to compute. This is not super fast, but still a large improvement compared to plotting all the points.

import matplotlib.pyplot as plt

# Generate some data
x = np.sort(np.random.rayleigh(3, size=30000))
y = np.cumsum(np.random.randn(len(x)))+250
# Choose the number of bins (301 edges give 300 bins)
bins = np.linspace(x.min(),x.max(),301)
# Filter data
xf, yf = filt(x,y,bins)

# Plot results
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, figsize=(7,8), 
                                    gridspec_kw=dict(height_ratios=[1,2,2]))

ax1.hist(x, bins=bins)
ax1.set_yscale("log")
ax1.set_yticks([1,10,100,1000])

ax2.plot(x,y, linewidth=1, label="original data, {} points".format(len(x)))

ax3.plot(xf, yf, linewidth=1, label="binned min/max, {} points".format(len(xf)))

for ax in [ax2, ax3]:
    ax.legend()
plt.show()

[image: log-scale histogram of x, original data, and binned min/max data]
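For comparison, the weighted random subsampling mentioned in the question could look roughly like the sketch below. It is not part of the method above; it assumes inverse bin population as the weight, uses synthetic Rayleigh data as a stand-in for the real data, and the target size n_keep is a hypothetical choice.

```python
import numpy as np

rng = np.random.default_rng(42)

# Synthetic stand-in for the data (same distribution as the example above)
x = np.sort(rng.rayleigh(3, size=30000))

# Bin edges along the dimension that defines the density
n_bins = 300
edges = np.linspace(x.min(), x.max(), n_bins + 1)

# 0-based bin index per point; np.digitize is 1-based and puts the
# maximum one past the last bin, so clip it back into range
idx = np.clip(np.digitize(x, edges) - 1, 0, n_bins - 1)

# Population of each bin, computed from the same indices
counts = np.bincount(idx, minlength=n_bins)

# Assumption: weight every point by the inverse population of its bin
weights = 1.0 / counts[idx]
p = weights / weights.sum()

# Draw without replacement so that no point is selected twice
n_keep = 500  # hypothetical target size
keep = np.sort(rng.choice(len(x), size=n_keep, replace=False, p=p))
x_sub = x[keep]
```

With these weights, every occupied bin carries the same total probability, so points in sparse bins are far more likely to be drawn than points in dense ones. The draw is still random, though, so a rare point can be missed, whereas the binned min/max filter keeps such points deterministically.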

