Vectorizing the Euclidean distance computation over multi-dimensional features for a radial basis function

Time: 2018-06-17 21:49:46

Tags: python numpy scikit-learn linear-algebra

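I suspect there is an SO post that already answers this, but I have not found it, so apologies in advance if this is a duplicate.

I am trying to implement the radial basis function kernel from scratch in NumPy for my own learning. For one-dimensional input, the computation is very simple: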

import numpy as np

def kernel(x, y):
    # Squared-exponential (RBF) kernel for 1D inputs x and y
    return np.exp(-0.5 * np.subtract.outer(x, y)**2)

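The above comes from a blog post on Gaussian Processes.

But I am trying to extend it to multiple dimensions. I have an implementation that works correctly, shown below: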

x = np.array([[4,3,5], [1,3,9], [0,1,0], [4,3,5]])
distances = []
γ = -.5
for i in x:
    for j in x:
        distances.append(np.exp(γ * np.linalg.norm(i - j) ** 2))
np.array(distances).reshape(len(x),len(x))

[[1.00000000e+00 3.72665317e-06 1.69189792e-10 1.00000000e+00]
 [3.72665317e-06 1.00000000e+00 2.11513104e-19 3.72665317e-06]
 [1.69189792e-10 2.11513104e-19 1.00000000e+00 1.69189792e-10]
 [1.00000000e+00 3.72665317e-06 1.69189792e-10 1.00000000e+00]]

I am checking the result against sklearn.metrics.pairwise.rbf_kernel:
from sklearn.metrics.pairwise import rbf_kernel
print(rbf_kernel(x, gamma= .5))

[[1.00000000e+00 3.72665317e-06 1.69189792e-10 1.00000000e+00]
 [3.72665317e-06 1.00000000e+00 2.11513104e-19 3.72665317e-06]
 [1.69189792e-10 2.11513104e-19 1.00000000e+00 1.69189792e-10]
 [1.00000000e+00 3.72665317e-06 1.69189792e-10 1.00000000e+00]]

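But clearly, the nested for loops are not the most efficient way to iterate over these pairs. What is the best way to vectorize this operation?

This SO post provides an efficient way to compute the distances, but not the vectorization I need.

2 answers:

Answer 0: (score: 1)

We could use SciPy's cdist and then scale and exponentiate the result: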

from scipy.spatial.distance import cdist

lam = -.5
out = np.exp(lam* cdist(x,x,'sqeuclidean'))

We could also leverage matrix multiplication:

def sqcdist_own(x):
    row_sum = (x**2).sum(1) # or np.einsum('ij,ij->i',x,x)
    sqeucdist = row_sum - 2*x.dot(x.T)
    sqeucdist += row_sum[:,None]
    return sqeucdist

out = np.exp(lam * sqcdist_own(x))
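
As a sanity check, the matrix-multiplication route can be compared against the explicit double loop from the question. The sketch below is self-contained and uses only NumPy. The helper name `rbf_loop` and the clipping step with `np.maximum` are additions for illustration and are not part of the answer above. The clipping is there because, with floating-point inputs, the expansion can yield tiny negative squared distances from rounding.

```python
import numpy as np

def sqcdist_own(x):
    # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2 * a.b
    row_sum = (x ** 2).sum(1)
    sqeucdist = row_sum - 2 * x.dot(x.T)
    sqeucdist += row_sum[:, None]
    # Assumption (not in the answer): clip rounding noise below zero
    return np.maximum(sqeucdist, 0)

def rbf_loop(x, lam):
    # Hypothetical reference implementation: the question's double loop
    n = len(x)
    out = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            out[i, j] = np.exp(lam * np.linalg.norm(x[i] - x[j]) ** 2)
    return out

x = np.array([[4, 3, 5], [1, 3, 9], [0, 1, 0], [4, 3, 5]], dtype=float)
lam = -.5
out = np.exp(lam * sqcdist_own(x))

# Compare the vectorized result with the loop-based one
print(np.allclose(out, rbf_loop(x, lam)))
```

On larger float-valued inputs the two results may differ in the last few digits, which is why the comparison uses `np.allclose` rather than exact equality.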

To use these approaches on 1D inputs, reshape x to 2D as a preprocessing step: X = x.reshape(len(x),-1), and then pass X instead of x as the input to these solutions.
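
A minimal sketch of the 1D case, assuming the same squared-distance expansion as in the matrix-multiplication approach (repeated here under the hypothetical name `sqdist` so the block runs on its own). After the reshape, the result should agree with the `np.subtract.outer` kernel from the question.

```python
import numpy as np

def sqdist(X):
    # Hypothetical helper: pairwise squared Euclidean distances of the rows of X
    row_sum = (X ** 2).sum(1)
    return row_sum[:, None] + row_sum - 2 * X.dot(X.T)

x = np.array([0.0, 1.0, 2.5, 4.0])   # 1D input, shape (4,)
X = x.reshape(len(x), -1)            # 2D column vector, shape (4, 1)

out_2d = np.exp(-0.5 * sqdist(X))
out_1d = np.exp(-0.5 * np.subtract.outer(x, x) ** 2)

# Both routes should give the same kernel matrix
print(np.allclose(out_2d, out_1d))
```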

Answer 1: (score: 1)

You can use the following observation to solve the problem:

In code, it looks as follows:
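
A minimal sketch of what such code could look like, assuming the observation meant here is the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>, the same expansion the first answer relies on. The function name `rbf_vectorized` and its `gamma` parameter are hypothetical. `gamma` follows the sign convention of `rbf_kernel`, so the kernel is exp(-gamma * ||a - b||^2).

```python
import numpy as np

def rbf_vectorized(x, y=None, gamma=0.5):
    # Hypothetical helper, not taken from the answer itself
    x = np.asarray(x, dtype=float)
    y = x if y is None else np.asarray(y, dtype=float)
    # Squared norms of every row: ||a||^2 and ||b||^2
    x_sq = np.einsum('ij,ij->i', x, x)
    y_sq = np.einsum('ij,ij->i', y, y)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, for all pairs at once
    sq = x_sq[:, None] + y_sq[None, :] - 2 * x @ y.T
    # Assumption: clip small negative values caused by rounding
    np.maximum(sq, 0, out=sq)
    return np.exp(-gamma * sq)

x = np.array([[4, 3, 5], [1, 3, 9], [0, 1, 0], [4, 3, 5]])
K = rbf_vectorized(x, gamma=0.5)
print(K)
```

With the question's `x` and `gamma=0.5`, the printed matrix can be compared directly with the `rbf_kernel(x, gamma=.5)` output shown in the question.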
