Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share


arrays - Dynamic parameterization of Armadillo matrix dimensions in C++

The title summarizes the goal; more precisely, I want to dynamically retrieve the dimensions of MATLAB arrays passed to Armadillo matrices.

Specifically, I would like to replace the hard-coded second and third constructor arguments of mY and mD below (the number of rows and columns) with parameters.

// mat(ptr_aux_mem, n_rows, n_cols, copy_aux_mem = true, strict = false)
arma::mat mY(&dY[0], 2, 168, false);
arma::mat mD(&dD[0], 2, 168, false);

This must be a common use case, but I still have not found a clean way to achieve it in the general case, where the arrays coming from MATLAB can have an arbitrary number of dimensions (n > 2).

For the matrix (two-dimensional) case I could probably hack my way around it, but that does not feel elegant (and is probably not efficient either).

IMHO, the way to go should be the following.

matlab::data::TypedArray<double> has a getDimensions() member function that returns matlab::data::ArrayDimensions, which is simply a std::vector<size_t>.

By indexing the first and second elements of the vector returned by getDimensions(), one can retrieve the number of rows and columns, for instance:

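const matlab::data::ArrayDimensions dimsY = matrixY.getDimensions();
const std::size_t mYrows = dimsY[0];  // size_t matches the vector's element type (no narrowing)
const std::size_t mYcols = dimsY[1];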
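Because matlab::data::ArrayDimensions is just a std::vector<size_t>, the n > 2 case can be handled on the dimension vector alone. A minimal sketch, assuming one wants to treat an N-d array as a 2-D matrix by keeping the first dimension as rows and folding all remaining dimensions into columns (what reshape(A, size(A,1), []) does in MATLAB). The names Dims, Shape2D and collapseToMatrix are hypothetical, and a plain std::vector stands in for ArrayDimensions so the snippet compiles without the MATLAB headers:

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Stand-in for matlab::data::ArrayDimensions (a std::vector<size_t>).
using Dims = std::vector<std::size_t>;

// Hypothetical result type: the 2-D shape used to wrap the buffer.
struct Shape2D {
    std::size_t rows;
    std::size_t cols;
};

// Hypothetical helper: keep dims[0] as rows and multiply all remaining
// dimensions together as columns. Since MATLAB stores data column-major,
// the folded dimensions are already contiguous, so no data is reordered.
Shape2D collapseToMatrix(const Dims& dims) {
    assert(!dims.empty() && "MATLAB arrays report at least two dimensions");
    const std::size_t rows = dims[0];
    const std::size_t cols = std::accumulate(dims.begin() + 1, dims.end(),
                                             std::size_t{1},
                                             std::multiplies<std::size_t>());
    return {rows, cols};
}
```

In the MEX code this could be called as collapseToMatrix(matrixY.getDimensions()), and the resulting rows and cols would replace the hard-coded 2 and 168 in the arma::mat constructor.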

However, with my current setup I cannot call getDimensions() through pointers or references inside foo() in sub.cpp, since foo() only receives raw double pointers. If feasible, I would like neither to create additional temporary objects nor to pass extra arguments to foo(). Is that possible?

My intuition keeps telling me there must be an elegant solution along those lines as well. Maybe using multiple indirection?
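A raw double* carries no shape information, so once only the pointer reaches foo() there is nothing left to call getDimensions() on; the shape has to travel with the data in some form. A minimal sketch, assuming one small non-owning object per array is acceptable (it copies the short dimension vector, never the array data). ArrayView and sumAll are hypothetical names, and sumAll merely stands in for the omitted Armadillo loop in foo():

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical non-owning view: the data pointer travels together with its
// dimension vector (same type as matlab::data::ArrayDimensions).
// The view neither owns nor copies the array data.
struct ArrayView {
    double* data;
    std::vector<std::size_t> dims;

    std::size_t rows() const { return dims.empty() ? 0 : dims[0]; }
    // 1 column if there is no second dimension.
    std::size_t cols() const { return dims.size() > 1 ? dims[1] : 1; }

    // Column-major access, the layout used by both MATLAB and Armadillo.
    double at(std::size_t r, std::size_t c) const {
        assert(r < rows() && c < cols());
        return data[r + c * rows()];
    }
};

// Hypothetical stand-in for the loop in foo(): the bounds come from the
// view instead of the hard-coded 2 and 168.
double sumAll(const ArrayView& v) {
    double s = 0.0;
    for (std::size_t c = 0; c < v.cols(); ++c)
        for (std::size_t r = 0; r < v.rows(); ++r)
            s += v.at(r, c);
    return s;
}
```

In the MEX code the view could be built as ArrayView{darrY, matrixY.getDimensions()}, and inside foo() the Armadillo matrix as arma::mat mY(v.data, v.rows(), v.cols(), false), so foo() keeps one argument per array.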

I would highly appreciate any help, hints or constructive comments from more knowledgeable SO members. Thank you in advance.

Setup:

Two C++ source files and a header file:

main.cpp

  • contains the general IO interface between MATLAB and C++
  • feeds two double arrays and two const double scalars into C++
  • does some Armadillo-based looping by calling foo() (that part is not important and is therefore omitted)
  • returns outp, which is "just" a plain scalar double
  • Nothing fancy or complicated.

sub.cpp

  • This is only for the foo() looping part.

sub.hpp

  • Just a simple header file.
// main.cpp
// MATLAB API Header Files
#include "mex.hpp"
#include "mexAdapter.hpp"

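// Custom header
#include "sub.hpp"

// buffer_ptr_t, ArrayType and Array used below live in namespace matlab::data
using namespace matlab::data;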

// Overloading the function call operator, thus class acts as a functor
class MexFunction : public matlab::mex::Function {
    public:
        void operator()(matlab::mex::ArgumentList outputs,
                        matlab::mex::ArgumentList inputs){
            
            matlab::data::ArrayFactory factory;
            // Validate arguments
            checkArguments(outputs, inputs);

            matlab::data::TypedArray<double> matrixY = std::move(inputs[0]);
            matlab::data::TypedArray<double> matrixD = std::move(inputs[1]);
            const double csT = inputs[2][0];
            const double csKy = inputs[3][0];

            buffer_ptr_t<double> mY = matrixY.release();
            buffer_ptr_t<double> mD = matrixD.release();

            double* darrY = mY.get();
            double* darrD = mD.get();

            // data type of outp is "just" a plain double, NOT a double array
            double outp = foo(darrY, darrD, csT, csKy);

            outputs[0] = factory.createScalar(outp);
        }

        void checkArguments(matlab::mex::ArgumentList outputs, matlab::mex::ArgumentList inputs){
            // Create pointer to MATLAB engine
            std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine();
            // Create array factory, allows us to create MATLAB arrays in C++
            matlab::data::ArrayFactory factory;
            // Check input size and types
            if (inputs[0].getType() != ArrayType::DOUBLE ||
                inputs[0].getType() == ArrayType::COMPLEX_DOUBLE)
            {
                // Throw error directly into MATLAB if type does not match
                matlabPtr->feval(u"error", 0,
                    std::vector<Array>({ factory.createScalar("Input must be double array.") }));
            }
            // Check output size
            if (outputs.size() > 1) {
                matlabPtr->feval(u"error", 0, 
                    std::vector<Array>({ factory.createScalar("Only one output is returned.") }));
            }
        }
};

// sub.cpp

#include "sub.hpp"
#include "armadillo"

double foo(double* dY, double* dD, const double T, const double Ky) {
    
    double sum = 0;

    // Conversion of input parameters to Armadillo types
    // mat(ptr_aux_mem, n_rows, n_cols, copy_aux_mem = true, strict = false)
    arma::mat mY(&dY[0], 2, 168, false);
    arma::mat mD(&dD[0], 2, 168, false);

    // Armadillo calculations

    for(int t=0; t<int(T); t++){

        // some armadillo based calculation
        // each for cycle increments sum by its return value 
    }

    return sum;
}

// sub.hpp

#ifndef SUB_H_INCLUDED
#define SUB_H_INCLUDED

double foo(double* dY, double* dD, const double T, const double Ky);

#endif // SUB_H_INCLUDED

1 Reply


One way is to convert each input to an Armadillo matrix with a helper function template:

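template<class T>
arma::Mat<T> getMat( matlab::data::TypedArray<T> A)
{
  // MATLAB and Armadillo both store matrices column-major, so the buffer
  // can be wrapped as-is; copy_aux_mem defaults to true, so the data is copied.
  // Only the first two dimensions are used.
  matlab::data::TypedIterator<T> it = A.begin();
  matlab::data::ArrayDimensions nDim = A.getDimensions();
  return arma::Mat<T>(it.operator->(), nDim[0], nDim[1]);
}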

and call it like this:

 arma::mat Y = getMat<double>(inputs[0]);
 arma::mat D = getMat<double>(inputs[1]);
 ...
 double outp = foo(Y,D, csT, csKy);

Then change foo() (and its declaration in sub.hpp, which then needs #include "armadillo") to:

double foo( arma::mat& dY, arma::mat& dD, const double T, const double Ky) 
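The getMat() helper above only reads nDim[0] and nDim[1], so for an input with more than two dimensions it would silently use only the first nDim[0]*nDim[1] elements. Armadillo also provides arma::Cube<T>, whose advanced constructor takes (ptr_aux_mem, n_rows, n_cols, n_slices) analogously to arma::Mat; wrapping the MATLAB buffer that way is valid because both libraries use the same column-major offset formula. A small standard-C++ sketch of that formula (colMajorOffset is a hypothetical helper, not part of either API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Linear offset of element (i, j, k), 0-based, in a rows x cols x slices
// array. MATLAB stores A(i+1, j+1, k+1) here, and arma::Cube stores slice k
// as a column-major matrix placed after slices 0..k-1: same formula.
std::size_t colMajorOffset(std::size_t i, std::size_t j, std::size_t k,
                           std::size_t rows, std::size_t cols) {
    return i + j * rows + k * rows * cols;
}

// General N-d version: dims as in matlab::data::ArrayDimensions,
// idx holds one 0-based subscript per dimension.
std::size_t colMajorOffset(const std::vector<std::size_t>& idx,
                           const std::vector<std::size_t>& dims) {
    assert(idx.size() == dims.size());
    std::size_t offset = 0;
    std::size_t stride = 1;  // product of all earlier dimensions
    for (std::size_t d = 0; d < dims.size(); ++d) {
        assert(idx[d] < dims[d]);
        offset += idx[d] * stride;
        stride *= dims[d];
    }
    return offset;
}
```

For example, in a 2 x 4 x 5 array, element (1, 2, 3) (0-based) sits at offset 1 + 2*2 + 3*8 = 29 under both conventions.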

...