Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

arrays - How to speed up reshape in higher rank tensor contraction by BLAS in Fortran?

Related question: "Fortran: Which method is faster to change the rank of arrays? (Reshape vs. Pointer)"

Suppose I have a tensor contraction A[a,b] * B[b,c,d] = C[a,c,d]. To do this with BLAS I think I need DGEMM (assuming real values), so I can

  1. first reshape tensor B[b,c,d] into D[b,e], where e = c*d,
  2. call DGEMM to compute A[a,b] * D[b,e] = E[a,e],
  3. reshape E[a,e] into C[a,c,d].

The problem is that reshape is not that fast :( I saw the discussion in "Fortran: Which method is faster to change the rank of arrays? (Reshape vs. Pointer)", but in that question the author ran into error messages with every method except reshape itself.

Thus, I am asking if there is a convenient solution.

1 Reply

[I have prefixed the sizes of the dimensions with the letter n, to avoid confusion below between a tensor and the size of a dimension.]

As discussed in the comments, there is no need to reshape. dgemm has no concept of tensors; it only knows about arrays, and all it cares about is that those arrays are laid out in the correct order in memory. Because Fortran is column-major, a 3-dimensional array representing the tensor B in the question is laid out in memory exactly like a 2-dimensional array representing the tensor D.

As far as the matrix multiplication is concerned, all you then need is for the dot products that form the result to have the right length. So if you tell dgemm that B has a leading dimension of nb and a second dimension of nc*nd, you will get the right result. This leads us to the following:

ian@eris:~/work/stack$ gfortran --version
GNU Fortran (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
Copyright (C) 2017 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

ian@eris:~/work/stack$ cat reshape.f90
Program reshape_for_blas

  Use, Intrinsic :: iso_fortran_env, Only :  wp => real64, li => int64

  Implicit None

  Real( wp ), Dimension( :, :    ), Allocatable :: a
  Real( wp ), Dimension( :, :, : ), Allocatable :: b
  Real( wp ), Dimension( :, :, : ), Allocatable :: c1, c2
  Real( wp ), Dimension( :, :    ), Allocatable :: d
  Real( wp ), Dimension( :, :    ), Allocatable :: e

  Integer :: na, nb, nc, nd, ne
  
  Integer( li ) :: start, finish, rate

  Write( *, * ) 'na, nb, nc, nd ?'
  Read( *, * ) na, nb, nc, nd
  ne = nc * nd
  Allocate( a ( 1:na, 1:nb ) ) 
  Allocate( b ( 1:nb, 1:nc, 1:nd ) ) 
  Allocate( c1( 1:na, 1:nc, 1:nd ) ) 
  Allocate( c2( 1:na, 1:nc, 1:nd ) ) 
  Allocate( d ( 1:nb, 1:ne ) ) 
  Allocate( e ( 1:na, 1:ne ) ) 

  ! Set up some data
  Call Random_number( a )
  Call Random_number( b )

  ! With reshapes
  Call System_clock( start, rate )
  d = Reshape( b, Shape( d ) )
  Call dgemm( 'N', 'N', na, ne, nb, 1.0_wp, a, Size( a, Dim = 1 ), &
                                            d, Size( d, Dim = 1 ), &
                                    0.0_wp, e, Size( e, Dim = 1 ) )
  c1 = Reshape( e, Shape( c1 ) )
  Call System_clock( finish, rate )
  Write( *, * ) 'Time for reshaping method ', Real( finish - start, wp ) / rate
  
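  ! Direct: pass the rank-3 arrays b and c2 straight to dgemm, which
  ! treats them as nb x ne and na x ne matrices with no copying
  Call System_clock( start, rate )
  Call dgemm( 'N', 'N', na, ne, nb, 1.0_wp, a , Size( a , Dim = 1 ), &
                                            b , Size( b , Dim = 1 ), &
                                    0.0_wp, c2, Size( c2, Dim = 1 ) )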
  Call System_clock( finish, rate )
  Write( *, * ) 'Time for straight  method ', Real( finish - start, wp ) / rate

  Write( *, * ) 'Difference between result matrices ', Maxval( Abs( c1 - c2 ) )

End Program reshape_for_blas
ian@eris:~/work/stack$ cat in
40 50 60 70
ian@eris:~/work/stack$ gfortran -std=f2008 -Wall -Wextra -fcheck=all reshape.f90  -lblas
ian@eris:~/work/stack$ ./a.out < in
 na, nb, nc, nd ?
 Time for reshaping method    1.0515256000000001E-002
 Time for straight  method    5.8608790000000003E-003
 Difference between result matrices    0.0000000000000000     
ian@eris:~/work/stack$ gfortran -std=f2008 -Wall -Wextra  reshape.f90  -lblas
ian@eris:~/work/stack$ ./a.out < in
 na, nb, nc, nd ?
 Time for reshaping method    1.3585931000000001E-002
 Time for straight  method    1.6730429999999999E-003
 Difference between result matrices    0.0000000000000000     
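
The layout argument behind the direct call can also be checked without BLAS. The sketch below is only an illustration: naive_gemm is a hypothetical stand-in for dgemm (not part of any library), written with explicit-shape dummy arguments so that, as with dgemm, a rank-3 actual argument can be passed where a rank-2 array is declared. The small sizes are arbitrary assumptions.

```fortran
Module matmul_explicit_shape

  Use, Intrinsic :: iso_fortran_env, Only : wp => real64

  Implicit None

Contains

  ! Hypothetical stand-in for dgemm: c = a * b. The dummies are explicit
  ! shape, so the actual arguments may have a different rank as long as
  ! they supply enough contiguous elements (sequence association).
  Subroutine naive_gemm( m, n, k, a, b, c )
    Integer   , Intent( In    ) :: m, n, k
    Real( wp ), Intent( In    ) :: a( m, k )
    Real( wp ), Intent( In    ) :: b( k, n )
    Real( wp ), Intent(   Out ) :: c( m, n )
    Integer :: i, j, l
    c = 0.0_wp
    Do j = 1, n
       Do l = 1, k
          Do i = 1, m
             c( i, j ) = c( i, j ) + a( i, l ) * b( l, j )
          End Do
       End Do
    End Do
  End Subroutine naive_gemm

End Module matmul_explicit_shape

Program layout_check

  Use matmul_explicit_shape

  Implicit None

  ! Arbitrary small sizes, chosen only for this illustration
  Integer, Parameter :: na = 3, nb = 4, nc = 5, nd = 6

  Real( wp ) :: a( na, nb ), b( nb, nc, nd ), c( na, nc, nd )
  Real( wp ) :: ref
  Integer    :: ia, ic, id
  Logical    :: ok

  Call Random_number( a )
  Call Random_number( b )

  ! Rank-3 b and c are passed where rank-2 dummies are declared
  Call naive_gemm( na, nc * nd, nb, a, b, c )

  ! Compare against the contraction written out index by index
  ok = .True.
  Do id = 1, nd
     Do ic = 1, nc
        Do ia = 1, na
           ref = Sum( a( ia, : ) * b( :, ic, id ) )
           If( Abs( c( ia, ic, id ) - ref ) > 1.0e-12_wp ) ok = .False.
        End Do
     End Do
  End Do

  If( .Not. ok ) Error Stop 'layout check failed'
  Write( *, '( a )' ) 'layout check passed'

End Program layout_check
```

If the column-major argument holds, every c( ia, ic, id ) matches the explicitly summed contraction and the program reports that the check passed; no Reshape appears anywhere.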

That said, it is worth noting that the overhead of reshaping is O(N^2) while the time for the matrix multiply is O(N^3): the two reshapes copy nb*nc*nd + na*nc*nd elements, whereas dgemm performs roughly 2*na*nb*nc*nd floating-point operations. Thus for large matrices the percentage overhead due to the reshapes tends to zero. Code performance is also not the only consideration; readability and maintainability are very important too. So if you find the reshape method much more readable, and the matrices you use are large enough that the overhead does not matter, you may well prefer the reshapes, since in that case readability might be more important than performance. Your call.
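
If what you want for readability is a rank-2 name for the data, one possible middle ground is a rank-remapping pointer, which gives a 2-dimensional view of b without copying anything. This is a minimal sketch rather than a recommendation: it assumes a compiler that supports Fortran 2008 pointer rank remapping onto a simply contiguous target, it requires b to have the Target attribute, and the sizes and the name d are chosen only for illustration.

```fortran
Program pointer_view

  Use, Intrinsic :: iso_fortran_env, Only : wp => real64

  Implicit None

  ! Arbitrary small sizes, chosen only for this illustration
  Integer, Parameter :: nb = 4, nc = 5, nd = 6

  Real( wp ), Dimension( :, :, : ), Allocatable, Target :: b
  Real( wp ), Dimension( :, :    ), Pointer             :: d

  Integer :: ic, id

  Allocate( b( 1:nb, 1:nc, 1:nd ) )
  Call Random_number( b )

  ! Rank-remapping pointer assignment: d is an nb x (nc*nd) view of the
  ! storage of b, no data is copied
  d( 1:nb, 1:nc * nd ) => b

  ! Column ic + ( id - 1 ) * nc of the view is the slice b( :, ic, id )
  Do id = 1, nd
     Do ic = 1, nc
        If( Any( d( :, ic + ( id - 1 ) * nc ) /= b( :, ic, id ) ) ) Then
           Error Stop 'view mismatch'
        End If
     End Do
  End Do

  ! Writing through the view changes b, showing that d is not a copy
  d( 1, nc * nd ) = -1.0_wp
  If( b( 1, nc, nd ) /= -1.0_wp ) Error Stop 'view is a copy'

  Write( *, '( a )' ) 'pointer view aliases b'

End Program pointer_view
```

The view d could then be passed to dgemm in place of the reshaped copy, although as shown above passing b directly gives the same result with less code.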

