I don't think this version does what you want it to do.
deltas.d[:,:,ti+ti2-1] makes a copy, so deltas.d will be unmodified by
gemm!. You can use sub as Viral mentions, or you could try ArrayViews.jl.
Another thing you could consider is using a Vector{Matrix{Float32}}
instead of an Array{Float32,3}. It can be slightly unintuitive, but
indexing a Vector{Matrix{Float32}} returns the stored matrix without
making a copy.
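
For example, here is a rough, untested sketch of the sub-based version,
keeping the same signature and deltas.d layout as in your code below:

function errprop!(w::Array{Float32,3}, d::Array{Float32,3}, deltas)
    deltas.d[:] = 0.
    for ti = 1:size(w,3), ti2 = 1:size(d,3)
        # sub gives a view into deltas.d, so gemm! accumulates in place
        dv = sub(deltas.d, 1:size(deltas.d,1), 1:size(deltas.d,2), ti+ti2-1)
        Base.LinAlg.BLAS.gemm!('T', 'N', one(Float32), w[:,:,ti],
                               d[:,:,ti2], one(Float32), dv)
    end
    deltas.d
end

The w[:,:,ti] and d[:,:,ti2] slices still allocate copies, but they are only
read, so the result is unchanged; they could be turned into views as well.
With the Vector{Matrix{Float32}} layout, w[ti] gives you the stored matrix
itself, so it can be passed straight to gemm!.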

Best regards

Andreas Noack

2014-09-14 15:08 GMT-04:00 Michael Oliver <michael.d.oli...@gmail.com>:

> I was using axpy to replace the += only and doing the matrix multiply in
> the argument to axpy. But you're right, gemm! is actually what I should be
> using (I'm just starting to learn the BLAS library names). Using gemm!, the
> code is now 1.68x faster than my Matlab code (I mean a whole epoch of
> backprop training)! And down to 40% gc time. My goal of a 2x speedup is in
> sight! I'll look into SubArrays next.
>
> function errprop!(w::Array{Float32,3}, d::Array{Float32,3}, deltas)
>         deltas.d[:] = 0.
>         for ti = 1:size(w,3), ti2 = 1:size(d,3)
>                 Base.LinAlg.BLAS.gemm!('T', 'N', one(Float32), w[:,:,ti], d[:,:,ti2], one(Float32), deltas.d[:,:,ti+ti2-1])
>         end
>         deltas.d
> end
>
> On Sunday, September 14, 2014 2:18:07 AM UTC-7, Viral Shah wrote:
>>
>> Oh never mind - I see that you have a matrix multiply there that benefits
>> from calling BLAS. If it is a matrix multiply, how come you can get away
>> with axpy? Shouldn’t you need a gemm?
>>
>> Another way to avoid creating temporary arrays with indexing is to use
>> SubArrays, which the linear algebra routines can work with.
>>
>> -viral
>>
>>
>>
>> > On 14-Sep-2014, at 2:43 pm, Viral Shah <vi...@mayin.org> wrote:
>> >
>> > That is great! However, by devectorizing, I meant writing the matrix
>> > multiply itself out as explicit loops, so that you end up with several
>> > nested loops. You basically do not want all those w[:,:,ti] calls that
>> > create matrices every time.
>> >
>> > You could also potentially hoist the deltas.d field access out of the
>> > loop. Try something like:
>> >
>> >
>> > function errprop!(w::Array{Float32,3}, d::Array{Float32,3}, deltas)
>> >         deltas.d[:] = 0.
>> >         dd = deltas.d
>> >         for ti = 1:size(w,3), ti2 = 1:size(d,3)
>> >                 # dd[:,:,ti+ti2-1] += w[:,:,ti]'*d[:,:,ti2], written out as loops
>> >                 for j = 1:size(d,2), i = 1:size(w,2), k = 1:size(w,1)
>> >                         dd[i,j,ti+ti2-1] += w[k,i,ti]*d[k,j,ti2]
>> >                 end
>> >         end
>> >         deltas.d
>> > end
>> >
>> >
>> > -viral
>> >
>> >
>> >
>> >> On 14-Sep-2014, at 12:47 pm, Michael Oliver <michael....@gmail.com>
>> wrote:
>> >>
>> >> Thanks Viral for the quick reply, that's good to know. I was able to
>> >> squeeze a little more performance out with axpy (see below). I tried
>> >> devectorizing the inner loop, but it was much slower, I believe because
>> >> it was no longer taking full advantage of MKL for the matrix multiply. So
>> >> far I've got the code running at 1.4x the speed of my Matlab version,
>> >> and according to @time I still have 44.41% gc time. So 0.4 can't come
>> >> soon enough! Great work, guys, I'm really enjoying learning Julia.
>> >>
>> >> function errprop!(w::Array{Float32,3}, d::Array{Float32,3}, deltas)
>> >>         deltas.d[:] = 0.
>> >>         rg = size(w,2)*size(d,2)
>> >>         for ti = 1:size(w,3), ti2 = 1:size(d,3)
>> >>                 Base.LinAlg.BLAS.axpy!(1, w[:,:,ti]'*d[:,:,ti2], range(1,rg), deltas.d[:,:,ti+ti2-1], range(1,rg))
>> >>         end
>> >>         deltas.d
>> >> end
>> >>
>> >> On Saturday, September 13, 2014 10:10:25 PM UTC-7, Viral Shah wrote:
>> >> The garbage is generated from the indexing operations. In 0.4, we should
>> >> have array views that should solve this problem. For now, you can either
>> >> manually devectorize the inner loop, or use the @devec macro in the
>> >> Devectorize package, if it works out in this case.
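>> >> (As an untested aside: the usage would be something like
>> >> @devec r = a .* b + c, which Devectorize rewrites into a single loop
>> >> without temporary arrays.)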
>> >>
>> >> -viral
>> >>
>> >> On Sunday, September 14, 2014 10:34:45 AM UTC+5:30, Michael Oliver
>> wrote:
>> >> Hi all,
>> >> I've implemented a time delay neural network module and am now trying to
>> >> optimize it. This function propagates the error backwards through the
>> >> network.
>> >> deltas.d is just a container for holding the errors, so I can do things
>> >> in place and don't have to keep initializing arrays. w and d are
>> >> collections of weights and errors, respectively, for different time lags.
>> >> This function gets called many, many times, and according to profiling
>> >> there is a lot of garbage collection being induced by the fourth line,
>> >> specifically within getindex and setindex! in multidimensional.jl and + in
>> >> array.jl:
>> >>
>> >> function errprop!(w::Array{Float32,3}, d::Array{Float32,3}, deltas)
>> >>         deltas.d[:] = 0.
>> >>         for ti=1:size(w,3), ti2 = 1:size(d,3)
>> >>             deltas.d[:,:,ti+ti2-1] += w[:,:,ti]'*d[:,:,ti2];
>> >>         end
>> >>         deltas.d
>> >> end
>> >>
>> >> Any advice would be much appreciated!
>> >> Best,
>> >> Michael
>> >
>>
>>
