pme.c

Berk Hess, 07/16/2012 05:55 PM

 
1
/* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2
 *
3
 *
4
 *                This source code is part of
5
 *
6
 *                 G   R   O   M   A   C   S
7
 *
8
 *          GROningen MAchine for Chemical Simulations
9
 *
10
 *                        VERSION 3.2.0
11
 * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12
 * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13
 * Copyright (c) 2001-2004, The GROMACS development team,
14
 * check out http://www.gromacs.org for more information.
15

16
 * This program is free software; you can redistribute it and/or
17
 * modify it under the terms of the GNU General Public License
18
 * as published by the Free Software Foundation; either version 2
19
 * of the License, or (at your option) any later version.
20
 *
21
 * If you want to redistribute modifications, please consider that
22
 * scientific software is very special. Version control is crucial -
23
 * bugs must be traceable. We will be happy to consider code for
24
 * inclusion in the official distribution, but derived work must not
25
 * be called official GROMACS. Details are found in the README & COPYING
26
 * files - if they are missing, get the official version at www.gromacs.org.
27
 *
28
 * To help us fund GROMACS development, we humbly ask that you cite
29
 * the papers on the package - you can find them in the top README file.
30
 *
31
 * For more info, check our website at http://www.gromacs.org
32
 *
33
 * And Hey:
34
 * GROwing Monsters And Cloning Shrimps
35
 */
36
/* IMPORTANT FOR DEVELOPERS:
37
 *
38
 * Triclinic pme stuff isn't entirely trivial, and we've experienced
39
 * some bugs during development (many of them due to me). To avoid
40
 * this in the future, please check the following things if you make
41
 * changes in this file:
42
 *
43
 * 1. You should obtain identical (at least to the PME precision)
44
 *    energies, forces, and virial for
45
 *    a rectangular box and a triclinic one where the z (or y) axis is
46
 *    tilted by a whole box side. For instance you could use these boxes:
47
 *
48
 *    rectangular       triclinic
49
 *     2  0  0           2  0  0
50
 *     0  2  0           0  2  0
51
 *     0  0  6           2  2  6
52
 *
53
 * 2. You should check the energy conservation in a triclinic box.
54
 *
55
 * It might seem like overkill, but better safe than sorry.
56
 * /Erik 001109
57
 */
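/* Illustration (not part of the original file, kept under #if 0 so it does
 * not affect compilation): a minimal sketch of the two test boxes from the
 * note above, using the GROMACS matrix type (real [3][3]). The helper name
 * is hypothetical.
 */
#if 0
static void set_pme_test_boxes(matrix box_rect, matrix box_tric)
{
    clear_mat(box_rect);
    box_rect[XX][XX] = 2;
    box_rect[YY][YY] = 2;
    box_rect[ZZ][ZZ] = 6;

    clear_mat(box_tric);
    box_tric[XX][XX] = 2;
    box_tric[YY][YY] = 2;
    box_tric[ZZ][XX] = 2;   /* z-vector tilted by a whole box side in x */
    box_tric[ZZ][YY] = 2;   /* ... and in y */
    box_tric[ZZ][ZZ] = 6;
}
#endif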
58

    
59
#ifdef HAVE_CONFIG_H
60
#include <config.h>
61
#endif
62

    
63
#ifdef GMX_LIB_MPI
64
#include <mpi.h>
65
#endif
66
#ifdef GMX_THREAD_MPI
67
#include "tmpi.h"
68
#endif
69

    
70
#ifdef GMX_OPENMP
71
#include <omp.h>
72
#endif
73

    
74
#include <stdio.h>
75
#include <string.h>
76
#include <math.h>
77
#include <assert.h>
78
#include "typedefs.h"
79
#include "txtdump.h"
80
#include "vec.h"
81
#include "gmxcomplex.h"
82
#include "smalloc.h"
83
#include "futil.h"
84
#include "coulomb.h"
85
#include "gmx_fatal.h"
86
#include "pme.h"
87
#include "network.h"
88
#include "physics.h"
89
#include "nrnb.h"
90
#include "copyrite.h"
91
#include "gmx_wallcycle.h"
92
#include "gmx_parallel_3dfft.h"
93
#include "pdbio.h"
94
#include "gmx_cyclecounter.h"
95

    
96
/* Single precision, with SSE2 or higher available */
97
#if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
98

    
99
#include "gmx_x86_sse2.h"
100
#include "gmx_math_x86_sse2_single.h"
101

    
102
#define PME_SSE
103
/* Some old AMD processors could have problems with unaligned loads+stores */
104
#ifndef GMX_FAHCORE
105
#define PME_SSE_UNALIGNED
106
#endif
107
#endif
108

    
109
#include "mpelogging.h"
110

    
111
#define DFT_TOL 1e-7
112
/* #define PRT_FORCE */
113
/* conditions for on-the-fly time measurement */
114
/* #define TAKETIME (step > 1 && timesteps < 10) */
115
#define TAKETIME FALSE
116

    
117
/* #define PME_TIME_THREADS */
118

    
119
#ifdef GMX_DOUBLE
120
#define mpi_type MPI_DOUBLE
121
#else
122
#define mpi_type MPI_FLOAT
123
#endif
124

    
125
/* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
126
#define GMX_CACHE_SEP 64
127

    
128
/* We only define a maximum to be able to use local arrays without allocation.
129
 * An order larger than 12 should never be needed, even for test cases.
130
 * If needed it can be changed here.
131
 */
132
#define PME_ORDER_MAX 12
133

    
134
/* Internal data structures */
135
typedef struct {
136
    int send_index0;
137
    int send_nindex;
138
    int recv_index0;
139
    int recv_nindex;
140
} pme_grid_comm_t;
141

    
142
typedef struct {
143
#ifdef GMX_MPI
144
    MPI_Comm mpi_comm;
145
#endif
146
    int  nnodes,nodeid;
147
    int  *s2g0;
148
    int  *s2g1;
149
    int  noverlap_nodes;
150
    int  *send_id,*recv_id;
151
    pme_grid_comm_t *comm_data;
152
    real *sendbuf;
153
    real *recvbuf;
154
} pme_overlap_t;
155

    
156
typedef struct {
157
    int *n;     /* Cumulative counts of the number of particles per thread */
158
    int nalloc; /* Allocation size of i */
159
    int *i;     /* Particle indices ordered on thread index (n) */
160
} thread_plist_t;
161

    
162
typedef struct {
163
    int  n;
164
    int  *ind;
165
    splinevec theta;
166
    splinevec dtheta;
167
} splinedata_t;
168

    
169
typedef struct {
170
    int  dimind;            /* The index of the dimension, 0=x, 1=y */
171
    int  nslab;
172
    int  nodeid;
173
#ifdef GMX_MPI
174
    MPI_Comm mpi_comm;
175
#endif
176

    
177
    int  *node_dest;        /* The nodes to send x and q to with DD */
178
    int  *node_src;         /* The nodes to receive x and q from with DD */
179
    int  *buf_index;        /* Index for commnode into the buffers */
180

    
181
    int  maxshift;
182

    
183
    int  npd;
184
    int  pd_nalloc;
185
    int  *pd;
186
    int  *count;            /* The number of atoms to send to each node */
187
    int  **count_thread;
188
    int  *rcount;           /* The number of atoms to receive */
189

    
190
    int  n;
191
    int  nalloc;
192
    rvec *x;
193
    real *q;
194
    rvec *f;
195
    gmx_bool bSpread;       /* These coordinates are used for spreading */
196
    int  pme_order;
197
    ivec *idx;
198
    rvec *fractx;            /* Fractional coordinate relative to the
199
                              * lower cell boundary
200
                              */
201
    int  nthread;
202
    int  *thread_idx;        /* Which thread should spread which charge */
203
    thread_plist_t *thread_plist;
204
    splinedata_t *spline;
205
} pme_atomcomm_t;
206

    
207
#define FLBS  3
208
#define FLBSZ 4
209

    
210
typedef struct {
211
    ivec ci;     /* The spatial location of this grid       */
212
    ivec n;      /* The size of *grid, including order-1    */
213
    ivec offset; /* The grid offset from the full node grid */
214
    int  order;  /* PME spreading order                     */
215
    real *grid;  /* The local grid for this thread, size n  */
216
} pmegrid_t;
217

    
218
typedef struct {
219
    pmegrid_t grid;     /* The full node grid (non thread-local)            */
220
    int  nthread;       /* The number of threads operating on this grid     */
221
    ivec nc;            /* The local spatial decomposition over the threads */
222
    pmegrid_t *grid_th; /* Array of grids for each thread                   */
223
    int  **g2t;         /* The grid to thread index                         */
224
    ivec nthread_comm;  /* The number of threads to communicate with        */
225
} pmegrids_t;
226

    
227

    
228
typedef struct {
229
#ifdef PME_SSE
230
    /* Masks for SSE aligned spreading and gathering */
231
    __m128 mask_SSE0[6],mask_SSE1[6];
232
#else
233
    int dummy; /* C89 requires that a struct has at least one member */
234
#endif
235
} pme_spline_work_t;
236

    
237
typedef struct {
238
    /* work data for solve_pme */
239
    int      nalloc;
240
    real *   mhx;
241
    real *   mhy;
242
    real *   mhz;
243
    real *   m2;
244
    real *   denom;
245
    real *   tmp1_alloc;
246
    real *   tmp1;
247
    real *   eterm;
248
    real *   m2inv;
249

    
250
    real     energy;
251
    matrix   vir;
252
} pme_work_t;
253

    
254
typedef struct gmx_pme {
255
    int  ndecompdim;         /* The number of decomposition dimensions */
256
    int  nodeid;             /* Our nodeid in mpi->mpi_comm */
257
    int  nodeid_major;
258
    int  nodeid_minor;
259
    int  nnodes;             /* The number of nodes doing PME */
260
    int  nnodes_major;
261
    int  nnodes_minor;
262

    
263
    MPI_Comm mpi_comm;
264
    MPI_Comm mpi_comm_d[2];  /* Indexed on dimension, 0=x, 1=y */
265
#ifdef GMX_MPI
266
    MPI_Datatype  rvec_mpi;  /* the pme vector's MPI type */
267
#endif
268

    
269
    int  nthread;            /* The number of threads doing PME */
270

    
271
    gmx_bool bPPnode;        /* Node also does particle-particle forces */
272
    gmx_bool bFEP;           /* Compute Free energy contribution */
273
    int nkx,nky,nkz;         /* Grid dimensions */
274
    gmx_bool bP3M;           /* Do P3M: optimize the influence function */
275
    int pme_order;
276
    real epsilon_r;
277

    
278
    pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
279
    pmegrids_t pmegridB;
280
    /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
281
    int     pmegrid_nx,pmegrid_ny,pmegrid_nz;
282
    /* pmegrid_nz might be larger than strictly necessary to ensure
283
     * memory alignment; pmegrid_nz_base gives the real base size.
284
     */
285
    int     pmegrid_nz_base;
286
    /* The local PME grid starting indices */
287
    int     pmegrid_start_ix,pmegrid_start_iy,pmegrid_start_iz;
288

    
289
    /* Work data for spreading and gathering */
290
    pme_spline_work_t *spline_work;
291

    
292
    real *fftgridA;             /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
293
    real *fftgridB;             /* inside the interpolation grid, but separate for 2D PME decomp. */
294
    int   fftgrid_nx,fftgrid_ny,fftgrid_nz;
295

    
296
    t_complex *cfftgridA;             /* Grids for complex FFT data */
297
    t_complex *cfftgridB;
298
    int   cfftgrid_nx,cfftgrid_ny,cfftgrid_nz;
299

    
300
    gmx_parallel_3dfft_t  pfft_setupA;
301
    gmx_parallel_3dfft_t  pfft_setupB;
302

    
303
    int  *nnx,*nny,*nnz;
304
    real *fshx,*fshy,*fshz;
305

    
306
    pme_atomcomm_t atc[2];  /* Indexed on decomposition index */
307
    matrix    recipbox;
308
    splinevec bsp_mod;
309

    
310
    pme_overlap_t overlap[2]; /* Indexed on dimension, 0=x, 1=y */
311

    
312
    pme_atomcomm_t atc_energy; /* Only for gmx_pme_calc_energy */
313

    
314
    rvec *bufv;             /* Communication buffer */
315
    real *bufr;             /* Communication buffer */
316
    int  buf_nalloc;        /* The communication buffer size */
317

    
318
    /* thread local work data for solve_pme */
319
    pme_work_t *work;
320

    
321
    /* Work data for PME_redist */
322
    gmx_bool redist_init;
323
    int *    scounts;
324
    int *    rcounts;
325
    int *    sdispls;
326
    int *    rdispls;
327
    int *    sidx;
328
    int *    idxa;
329
    real *   redist_buf;
330
    int      redist_buf_nalloc;
331

    
332
    /* Work data for sum_qgrid */
333
    real *   sum_qgrid_tmp;
334
    real *   sum_qgrid_dd_tmp;
335
} t_gmx_pme;
336

    
337

    
338
static void calc_interpolation_idx(gmx_pme_t pme,pme_atomcomm_t *atc,
339
                                   int start,int end,int thread)
340
{
341
    int  i;
342
    int  *idxptr,tix,tiy,tiz;
343
    real *xptr,*fptr,tx,ty,tz;
344
    real rxx,ryx,ryy,rzx,rzy,rzz;
345
    int  nx,ny,nz;
346
    int  start_ix,start_iy,start_iz;
347
    int  *g2tx,*g2ty,*g2tz;
348
    gmx_bool bThreads;
349
    int  *thread_idx=NULL;
350
    thread_plist_t *tpl=NULL;
351
    int  *tpl_n=NULL;
352
    int  thread_i;
353

    
354
    nx  = pme->nkx;
355
    ny  = pme->nky;
356
    nz  = pme->nkz;
357

    
358
    start_ix = pme->pmegrid_start_ix;
359
    start_iy = pme->pmegrid_start_iy;
360
    start_iz = pme->pmegrid_start_iz;
361

    
362
    rxx = pme->recipbox[XX][XX];
363
    ryx = pme->recipbox[YY][XX];
364
    ryy = pme->recipbox[YY][YY];
365
    rzx = pme->recipbox[ZZ][XX];
366
    rzy = pme->recipbox[ZZ][YY];
367
    rzz = pme->recipbox[ZZ][ZZ];
368

    
369
    g2tx = pme->pmegridA.g2t[XX];
370
    g2ty = pme->pmegridA.g2t[YY];
371
    g2tz = pme->pmegridA.g2t[ZZ];
372

    
373
    bThreads = (atc->nthread > 1);
374
    if (bThreads)
375
    {
376
        thread_idx = atc->thread_idx;
377

    
378
        tpl   = &atc->thread_plist[thread];
379
        tpl_n = tpl->n;
380
        for(i=0; i<atc->nthread; i++)
381
        {
382
            tpl_n[i] = 0;
383
        }
384
    }
385

    
386
    for(i=start; i<end; i++) {
387
        xptr   = atc->x[i];
388
        idxptr = atc->idx[i];
389
        fptr   = atc->fractx[i];
390

    
391
        /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
392
        tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
393
        ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
394
        tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
395

    
396
        tix = (int)(tx);
397
        tiy = (int)(ty);
398
        tiz = (int)(tz);
399

    
400
        /* Because decomposition only occurs in x and y,
401
         * we never have a fraction correction in z.
402
         */
403
        fptr[XX] = tx - tix + pme->fshx[tix];
404
        fptr[YY] = ty - tiy + pme->fshy[tiy];
405
        fptr[ZZ] = tz - tiz;
406

    
407
        idxptr[XX] = pme->nnx[tix];
408
        idxptr[YY] = pme->nny[tiy];
409
        idxptr[ZZ] = pme->nnz[tiz];
410

    
411
#ifdef DEBUG
412
        range_check(idxptr[XX],0,pme->pmegrid_nx);
413
        range_check(idxptr[YY],0,pme->pmegrid_ny);
414
        range_check(idxptr[ZZ],0,pme->pmegrid_nz);
415
#endif
416

    
417
        if (bThreads)
418
        {
419
            thread_i = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
420
            thread_idx[i] = thread_i;
421
            tpl_n[thread_i]++;
422
        }
423
    }
424

    
425
    if (bThreads)
426
    {
427
        /* Make a list of particle indices sorted on thread */
428

    
429
        /* Get the cumulative count */
430
        for(i=1; i<atc->nthread; i++)
431
        {
432
            tpl_n[i] += tpl_n[i-1];
433
        }
434
        /* The current implementation distributes particles equally
435
         * over the threads, so we could actually allocate for that
436
         * in pme_realloc_atomcomm_things.
437
         */
438
        if (tpl_n[atc->nthread-1] > tpl->nalloc)
439
        {
440
            tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
441
            srenew(tpl->i,tpl->nalloc);
442
        }
443
        /* Set tpl_n to the cumulative start */
444
        for(i=atc->nthread-1; i>=1; i--)
445
        {
446
            tpl_n[i] = tpl_n[i-1];
447
        }
448
        tpl_n[0] = 0;
449

    
450
        /* Fill our thread local array with indices sorted on thread */
451
        for(i=start; i<end; i++)
452
        {
453
            tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
454
        }
455
        /* Now tpl_n contains the cumulative count again */
456
    }
457
}
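/* Sketch (illustration only, under #if 0): the counting sort used above to
 * build the per-thread particle lists. Count particles per thread, turn the
 * counts into cumulative start offsets, then scatter the particle indices.
 * This stand-alone function and its arguments are hypothetical.
 */
#if 0
static void counting_sort_by_thread(int n, const int *thread_idx, int nthread,
                                    int *bucket_n, int *sorted_ind)
{
    int i, t;

    for (t = 0; t < nthread; t++)
    {
        bucket_n[t] = 0;
    }
    for (i = 0; i < n; i++)
    {
        bucket_n[thread_idx[i]]++;          /* per-thread counts */
    }
    for (t = 1; t < nthread; t++)
    {
        bucket_n[t] += bucket_n[t-1];       /* cumulative counts */
    }
    for (t = nthread-1; t >= 1; t--)
    {
        bucket_n[t] = bucket_n[t-1];        /* shift to start offsets */
    }
    bucket_n[0] = 0;
    for (i = 0; i < n; i++)
    {
        sorted_ind[bucket_n[thread_idx[i]]++] = i;
    }
}
#endif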
458

    
459
static void make_thread_local_ind(pme_atomcomm_t *atc,
460
                                  int thread,splinedata_t *spline)
461
{
462
    int  n,t,i,start,end;
463
    thread_plist_t *tpl;
464

    
465
    /* Combine the indices made by each thread into one index */
466

    
467
    n = 0;
468
    start = 0;
469
    for(t=0; t<atc->nthread; t++)
470
    {
471
        tpl = &atc->thread_plist[t];
472
        /* Copy our part (start - end) from the list of thread t */
473
        if (thread > 0)
474
        {
475
            start = tpl->n[thread-1];
476
        }
477
        end = tpl->n[thread];
478
        for(i=start; i<end; i++)
479
        {
480
            spline->ind[n++] = tpl->i[i];
481
        }
482
    }
483

    
484
    spline->n = n;
485
}
486

    
487

    
488
static void pme_calc_pidx(int start, int end,
489
                          matrix recipbox, rvec x[],
490
                          pme_atomcomm_t *atc, int *count)
491
{
492
    int  nslab,i;
493
    int  si;
494
    real *xptr,s;
495
    real rxx,ryx,rzx,ryy,rzy;
496
    int *pd;
497

    
498
    /* Calculate the PME task index (pidx) for each particle.
499
     * Here we always assign equally sized slabs to each node
500
     * for load balancing reasons (the PME grid spacing is not used).
501
     */
502

    
503
    nslab = atc->nslab;
504
    pd    = atc->pd;
505

    
506
    /* Reset the count */
507
    for(i=0; i<nslab; i++)
508
    {
509
        count[i] = 0;
510
    }
511

    
512
    if (atc->dimind == 0)
513
    {
514
        rxx = recipbox[XX][XX];
515
        ryx = recipbox[YY][XX];
516
        rzx = recipbox[ZZ][XX];
517
        /* Calculate the node index in x-dimension */
518
        for(i=start; i<end; i++)
519
        {
520
            xptr   = x[i];
521
            /* Fractional coordinates along box vectors */
522
            s = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
523
            si = (int)(s + 2*nslab) % nslab;
524
            pd[i] = si;
525
            count[si]++;
526
        }
527
    }
528
    else
529
    {
530
        ryy = recipbox[YY][YY];
531
        rzy = recipbox[ZZ][YY];
532
        /* Calculate the node index in y-dimension */
533
        for(i=start; i<end; i++)
534
        {
535
            xptr   = x[i];
536
            /* Fractional coordinates along box vectors */
537
            s = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
538
            si = (int)(s + 2*nslab) % nslab;
539
            pd[i] = si;
540
            count[si]++;
541
        }
542
    }
543
}
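/* Worked example for the slab assignment above (illustration only):
 * s is nslab times the fractional coordinate along the slab dimension.
 * With nslab = 4 and s = -0.3 (a particle slightly outside the box),
 * (int)(s + 2*nslab) = (int)(7.7) = 7 and 7 % 4 = 3, so the particle is
 * assigned to the last slab. The +2*nslab shift keeps the int truncation
 * and the modulo well defined for moderately negative coordinates.
 */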
544

    
545
static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
546
                                  pme_atomcomm_t *atc)
547
{
548
    int nthread,thread,slab;
549

    
550
    nthread = atc->nthread;
551

    
552
#pragma omp parallel for num_threads(nthread) schedule(static)
553
    for(thread=0; thread<nthread; thread++)
554
    {
555
        pme_calc_pidx(natoms* thread   /nthread,
556
                      natoms*(thread+1)/nthread,
557
                      recipbox,x,atc,atc->count_thread[thread]);
558
    }
559
    /* Non-parallel reduction, since nslab is small */
560

    
561
    for(thread=1; thread<nthread; thread++)
562
    {
563
        for(slab=0; slab<atc->nslab; slab++)
564
        {
565
            atc->count_thread[0][slab] += atc->count_thread[thread][slab];
566
        }
567
    }
568
}
569

    
570
static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
571
{
572
    int i,d;
573

    
574
    srenew(spline->ind,atc->nalloc);
575
    /* Initialize the index to identity so it works without threads */
576
    for(i=0; i<atc->nalloc; i++)
577
    {
578
        spline->ind[i] = i;
579
    }
580

    
581
    for(d=0;d<DIM;d++)
582
    {
583
        srenew(spline->theta[d] ,atc->pme_order*atc->nalloc);
584
        srenew(spline->dtheta[d],atc->pme_order*atc->nalloc);
585
    }
586
}
587

    
588
static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
589
{
590
    int nalloc_old,i,j,nalloc_tpl;
591

    
592
    /* atc->x must not be a NULL pointer, to avoid
593
     * possible fatal errors in MPI routines.
594
     */
595
    if (atc->n > atc->nalloc || atc->nalloc == 0)
596
    {
597
        nalloc_old = atc->nalloc;
598
        atc->nalloc = over_alloc_dd(max(atc->n,1));
599

    
600
        if (atc->nslab > 1) {
601
            srenew(atc->x,atc->nalloc);
602
            srenew(atc->q,atc->nalloc);
603
            srenew(atc->f,atc->nalloc);
604
            for(i=nalloc_old; i<atc->nalloc; i++)
605
            {
606
                clear_rvec(atc->f[i]);
607
            }
608
        }
609
        if (atc->bSpread) {
610
            srenew(atc->fractx,atc->nalloc);
611
            srenew(atc->idx   ,atc->nalloc);
612

    
613
            if (atc->nthread > 1)
614
            {
615
                srenew(atc->thread_idx,atc->nalloc);
616
            }
617

    
618
            for(i=0; i<atc->nthread; i++)
619
            {
620
                pme_realloc_splinedata(&atc->spline[i],atc);
621
            }
622
        }
623
    }
624
}
625

    
626
static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
627
                         int n, gmx_bool bXF, rvec *x_f, real *charge,
628
                         pme_atomcomm_t *atc)
629
/* Redistribute particle data for PME calculation */
630
/* domain decomposition by x coordinate           */
631
{
632
    int *idxa;
633
    int i, ii;
634

    
635
    if(FALSE == pme->redist_init) {
636
        snew(pme->scounts,atc->nslab);
637
        snew(pme->rcounts,atc->nslab);
638
        snew(pme->sdispls,atc->nslab);
639
        snew(pme->rdispls,atc->nslab);
640
        snew(pme->sidx,atc->nslab);
641
        pme->redist_init = TRUE;
642
    }
643
    if (n > pme->redist_buf_nalloc) {
644
        pme->redist_buf_nalloc = over_alloc_dd(n);
645
        srenew(pme->redist_buf,pme->redist_buf_nalloc*DIM);
646
    }
647

    
648
    pme->idxa = atc->pd;
649

    
650
#ifdef GMX_MPI
651
    if (forw && bXF) {
652
        /* forward, redistribution from pp to pme */
653

    
654
        /* Calculate send counts and exchange them with other nodes */
655
        for(i=0; (i<atc->nslab); i++) pme->scounts[i]=0;
656
        for(i=0; (i<n); i++) pme->scounts[pme->idxa[i]]++;
657
        MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
658

    
659
        /* Calculate send and receive displacements and index into send
660
           buffer */
661
        pme->sdispls[0]=0;
662
        pme->rdispls[0]=0;
663
        pme->sidx[0]=0;
664
        for(i=1; i<atc->nslab; i++) {
665
            pme->sdispls[i]=pme->sdispls[i-1]+pme->scounts[i-1];
666
            pme->rdispls[i]=pme->rdispls[i-1]+pme->rcounts[i-1];
667
            pme->sidx[i]=pme->sdispls[i];
668
        }
669
        /* Total # of particles to be received */
670
        atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
671

    
672
        pme_realloc_atomcomm_things(atc);
673

    
674
        /* Copy particle coordinates into send buffer and exchange */
675
        for(i=0; (i<n); i++) {
676
            ii=DIM*pme->sidx[pme->idxa[i]];
677
            pme->sidx[pme->idxa[i]]++;
678
            pme->redist_buf[ii+XX]=x_f[i][XX];
679
            pme->redist_buf[ii+YY]=x_f[i][YY];
680
            pme->redist_buf[ii+ZZ]=x_f[i][ZZ];
681
        }
682
        MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
683
                      pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
684
                      pme->rvec_mpi, atc->mpi_comm);
685
    }
686
    if (forw) {
687
        /* Copy charge into send buffer and exchange */
688
        for(i=0; i<atc->nslab; i++) pme->sidx[i]=pme->sdispls[i];
689
        for(i=0; (i<n); i++) {
690
            ii=pme->sidx[pme->idxa[i]];
691
            pme->sidx[pme->idxa[i]]++;
692
            pme->redist_buf[ii]=charge[i];
693
        }
694
        MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
695
                      atc->q, pme->rcounts, pme->rdispls, mpi_type,
696
                      atc->mpi_comm);
697
    }
698
    else { /* backward, redistribution from pme to pp */
699
        MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
700
                      pme->redist_buf, pme->scounts, pme->sdispls,
701
                      pme->rvec_mpi, atc->mpi_comm);
702

    
703
        /* Copy data from receive buffer */
704
        for(i=0; i<atc->nslab; i++)
705
            pme->sidx[i] = pme->sdispls[i];
706
        for(i=0; (i<n); i++) {
707
            ii = DIM*pme->sidx[pme->idxa[i]];
708
            x_f[i][XX] += pme->redist_buf[ii+XX];
709
            x_f[i][YY] += pme->redist_buf[ii+YY];
710
            x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
711
            pme->sidx[pme->idxa[i]]++;
712
        }
713
    }
714
#endif
715
}
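/* Worked example for the displacement construction above (illustration only):
 * with atc->nslab = 3 and send counts scounts = {3, 1, 2}, the send
 * displacements become sdispls = {0, 3, 4}: each slab's block starts right
 * after the previous one in the packed send buffer, which is the layout
 * MPI_Alltoallv expects.
 */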
716

    
717
static void pme_dd_sendrecv(pme_atomcomm_t *atc,
718
                            gmx_bool bBackward,int shift,
719
                            void *buf_s,int nbyte_s,
720
                            void *buf_r,int nbyte_r)
721
{
722
#ifdef GMX_MPI
723
    int dest,src;
724
    MPI_Status stat;
725

    
726
    if (bBackward == FALSE) {
727
        dest = atc->node_dest[shift];
728
        src  = atc->node_src[shift];
729
    } else {
730
        dest = atc->node_src[shift];
731
        src  = atc->node_dest[shift];
732
    }
733

    
734
    if (nbyte_s > 0 && nbyte_r > 0) {
735
        MPI_Sendrecv(buf_s,nbyte_s,MPI_BYTE,
736
                     dest,shift,
737
                     buf_r,nbyte_r,MPI_BYTE,
738
                     src,shift,
739
                     atc->mpi_comm,&stat);
740
    } else if (nbyte_s > 0) {
741
        MPI_Send(buf_s,nbyte_s,MPI_BYTE,
742
                 dest,shift,
743
                 atc->mpi_comm);
744
    } else if (nbyte_r > 0) {
745
        MPI_Recv(buf_r,nbyte_r,MPI_BYTE,
746
                 src,shift,
747
                 atc->mpi_comm,&stat);
748
    }
749
#endif
750
}
751

    
752
static void dd_pmeredist_x_q(gmx_pme_t pme,
753
                             int n, gmx_bool bX, rvec *x, real *charge,
754
                             pme_atomcomm_t *atc)
755
{
756
    int *commnode,*buf_index;
757
    int nnodes_comm,i,nsend,local_pos,buf_pos,node,scount,rcount;
758

    
759
    commnode  = atc->node_dest;
760
    buf_index = atc->buf_index;
761

    
762
    nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
763

    
764
    nsend = 0;
765
    for(i=0; i<nnodes_comm; i++) {
766
        buf_index[commnode[i]] = nsend;
767
        nsend += atc->count[commnode[i]];
768
    }
769
    if (bX) {
770
        if (atc->count[atc->nodeid] + nsend != n)
771
            gmx_fatal(FARGS,"%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
772
                      "This usually means that your system is not well equilibrated.",
773
                      n - (atc->count[atc->nodeid] + nsend),
774
                      pme->nodeid,'x'+atc->dimind);
775

    
776
        if (nsend > pme->buf_nalloc) {
777
            pme->buf_nalloc = over_alloc_dd(nsend);
778
            srenew(pme->bufv,pme->buf_nalloc);
779
            srenew(pme->bufr,pme->buf_nalloc);
780
        }
781

    
782
        atc->n = atc->count[atc->nodeid];
783
        for(i=0; i<nnodes_comm; i++) {
784
            scount = atc->count[commnode[i]];
785
            /* Communicate the count */
786
            if (debug)
787
                fprintf(debug,"dimind %d PME node %d send to node %d: %d\n",
788
                        atc->dimind,atc->nodeid,commnode[i],scount);
789
            pme_dd_sendrecv(atc,FALSE,i,
790
                            &scount,sizeof(int),
791
                            &atc->rcount[i],sizeof(int));
792
            atc->n += atc->rcount[i];
793
        }
794

    
795
        pme_realloc_atomcomm_things(atc);
796
    }
797

    
798
    local_pos = 0;
799
    for(i=0; i<n; i++) {
800
        node = atc->pd[i];
801
        if (node == atc->nodeid) {
802
            /* Copy direct to the receive buffer */
803
            if (bX) {
804
                copy_rvec(x[i],atc->x[local_pos]);
805
            }
806
            atc->q[local_pos] = charge[i];
807
            local_pos++;
808
        } else {
809
            /* Copy to the send buffer */
810
            if (bX) {
811
                copy_rvec(x[i],pme->bufv[buf_index[node]]);
812
            }
813
            pme->bufr[buf_index[node]] = charge[i];
814
            buf_index[node]++;
815
        }
816
    }
817

    
818
    buf_pos = 0;
819
    for(i=0; i<nnodes_comm; i++) {
820
        scount = atc->count[commnode[i]];
821
        rcount = atc->rcount[i];
822
        if (scount > 0 || rcount > 0) {
823
            if (bX) {
824
                /* Communicate the coordinates */
825
                pme_dd_sendrecv(atc,FALSE,i,
826
                                pme->bufv[buf_pos],scount*sizeof(rvec),
827
                                atc->x[local_pos],rcount*sizeof(rvec));
828
            }
829
            /* Communicate the charges */
830
            pme_dd_sendrecv(atc,FALSE,i,
831
                            pme->bufr+buf_pos,scount*sizeof(real),
832
                            atc->q+local_pos,rcount*sizeof(real));
833
            buf_pos   += scount;
834
            local_pos += atc->rcount[i];
835
        }
836
    }
837
}
838

    
839
static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
840
                           int n, rvec *f,
841
                           gmx_bool bAddF)
842
{
843
  int *commnode,*buf_index;
844
  int nnodes_comm,local_pos,buf_pos,i,scount,rcount,node;
845

    
846
  commnode  = atc->node_dest;
847
  buf_index = atc->buf_index;
848

    
849
  nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
850

    
851
  local_pos = atc->count[atc->nodeid];
852
  buf_pos = 0;
853
  for(i=0; i<nnodes_comm; i++) {
854
    scount = atc->rcount[i];
855
    rcount = atc->count[commnode[i]];
856
    if (scount > 0 || rcount > 0) {
857
      /* Communicate the forces */
858
      pme_dd_sendrecv(atc,TRUE,i,
859
                      atc->f[local_pos],scount*sizeof(rvec),
860
                      pme->bufv[buf_pos],rcount*sizeof(rvec));
861
      local_pos += scount;
862
    }
863
    buf_index[commnode[i]] = buf_pos;
864
    buf_pos   += rcount;
865
  }
866

    
867
    local_pos = 0;
868
    if (bAddF)
869
    {
870
        for(i=0; i<n; i++)
871
        {
872
            node = atc->pd[i];
873
            if (node == atc->nodeid)
874
            {
875
                /* Add from the local force array */
876
                rvec_inc(f[i],atc->f[local_pos]);
877
                local_pos++;
878
            }
879
            else
880
            {
881
                /* Add from the receive buffer */
882
                rvec_inc(f[i],pme->bufv[buf_index[node]]);
883
                buf_index[node]++;
884
            }
885
        }
886
    }
887
    else
888
    {
889
        for(i=0; i<n; i++)
890
        {
891
            node = atc->pd[i];
892
            if (node == atc->nodeid)
893
            {
894
                /* Copy from the local force array */
895
                copy_rvec(atc->f[local_pos],f[i]);
896
                local_pos++;
897
            }
898
            else
899
            {
900
                /* Copy from the receive buffer */
901
                copy_rvec(pme->bufv[buf_index[node]],f[i]);
902
                buf_index[node]++;
903
            }
904
        }
905
    }
906
}
907

    
908
#ifdef GMX_MPI
909
static void
910
gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
911
{
912
    pme_overlap_t *overlap;
913
    int send_index0,send_nindex;
914
    int recv_index0,recv_nindex;
915
    MPI_Status stat;
916
    int i,j,k,ix,iy,iz,icnt;
917
    int ipulse,send_id,recv_id,datasize;
918
    real *p;
919
    real *sendptr,*recvptr;
920

    
921
    /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
922
    overlap = &pme->overlap[1];
923

    
924
    for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
925
    {
926
        /* Since we have already (un)wrapped the overlap in the z-dimension,
927
         * we only have to communicate 0 to nkz (not pmegrid_nz).
928
         */
929
        if (direction==GMX_SUM_QGRID_FORWARD)
930
        {
931
            send_id = overlap->send_id[ipulse];
932
            recv_id = overlap->recv_id[ipulse];
933
            send_index0   = overlap->comm_data[ipulse].send_index0;
934
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
935
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
936
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
937
        }
938
        else
939
        {
940
            send_id = overlap->recv_id[ipulse];
941
            recv_id = overlap->send_id[ipulse];
942
            send_index0   = overlap->comm_data[ipulse].recv_index0;
943
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
944
            recv_index0   = overlap->comm_data[ipulse].send_index0;
945
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
946
        }
947

    
948
        /* Copy data to contiguous send buffer */
949
        if (debug)
950
        {
951
            fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
952
                    pme->nodeid,overlap->nodeid,send_id,
953
                    pme->pmegrid_start_iy,
954
                    send_index0-pme->pmegrid_start_iy,
955
                    send_index0-pme->pmegrid_start_iy+send_nindex);
956
        }
957
        icnt = 0;
958
        for(i=0;i<pme->pmegrid_nx;i++)
959
        {
960
            ix = i;
961
            for(j=0;j<send_nindex;j++)
962
            {
963
                iy = j + send_index0 - pme->pmegrid_start_iy;
964
                for(k=0;k<pme->nkz;k++)
965
                {
966
                    iz = k;
967
                    overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
968
                }
969
            }
970
        }
971

    
972
        datasize      = pme->pmegrid_nx * pme->nkz;
973

    
974
        MPI_Sendrecv(overlap->sendbuf,send_nindex*datasize,GMX_MPI_REAL,
975
                     send_id,ipulse,
976
                     overlap->recvbuf,recv_nindex*datasize,GMX_MPI_REAL,
977
                     recv_id,ipulse,
978
                     overlap->mpi_comm,&stat);
979

    
980
        /* Get data from contiguous recv buffer */
981
        if (debug)
982
        {
983
            fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
984
                    pme->nodeid,overlap->nodeid,recv_id,
985
                    pme->pmegrid_start_iy,
986
                    recv_index0-pme->pmegrid_start_iy,
987
                    recv_index0-pme->pmegrid_start_iy+recv_nindex);
988
        }
989
        icnt = 0;
990
        for(i=0;i<pme->pmegrid_nx;i++)
991
        {
992
            ix = i;
993
            for(j=0;j<recv_nindex;j++)
994
            {
995
                iy = j + recv_index0 - pme->pmegrid_start_iy;
996
                for(k=0;k<pme->nkz;k++)
997
                {
998
                    iz = k;
999
                    if(direction==GMX_SUM_QGRID_FORWARD)
1000
                    {
1001
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1002
                    }
1003
                    else
1004
                    {
1005
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1006
                    }
1007
                }
1008
            }
1009
        }
1010
    }
1011

    
1012
    /* Major dimension is easier, no copying required,
1013
     * but we might have to sum into a separate array.
1014
     * Since we don't copy, we have to communicate up to pmegrid_nz,
1015
     * not nkz as for the minor direction.
1016
     */
1017
    overlap = &pme->overlap[0];
1018

    
1019
    for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
1020
    {
1021
        if(direction==GMX_SUM_QGRID_FORWARD)
1022
        {
1023
            send_id = overlap->send_id[ipulse];
1024
            recv_id = overlap->recv_id[ipulse];
1025
            send_index0   = overlap->comm_data[ipulse].send_index0;
1026
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
1027
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
1028
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1029
            recvptr   = overlap->recvbuf;
1030
        }
1031
        else
1032
        {
1033
            send_id = overlap->recv_id[ipulse];
1034
            recv_id = overlap->send_id[ipulse];
1035
            send_index0   = overlap->comm_data[ipulse].recv_index0;
1036
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1037
            recv_index0   = overlap->comm_data[ipulse].send_index0;
1038
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1039
            recvptr   = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1040
        }
1041

    
1042
        sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1043
        datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1044

    
1045
        if (debug)
1046
        {
1047
            fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1048
                    pme->nodeid,overlap->nodeid,send_id,
1049
                    pme->pmegrid_start_ix,
1050
                    send_index0-pme->pmegrid_start_ix,
1051
                    send_index0-pme->pmegrid_start_ix+send_nindex);
1052
            fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1053
                    pme->nodeid,overlap->nodeid,recv_id,
1054
                    pme->pmegrid_start_ix,
1055
                    recv_index0-pme->pmegrid_start_ix,
1056
                    recv_index0-pme->pmegrid_start_ix+recv_nindex);
1057
        }
1058

    
1059
        MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
1060
                     send_id,ipulse,
1061
                     recvptr,recv_nindex*datasize,GMX_MPI_REAL,
1062
                     recv_id,ipulse,
1063
                     overlap->mpi_comm,&stat);
1064

    
1065
        /* ADD data from contiguous recv buffer */
1066
        if(direction==GMX_SUM_QGRID_FORWARD)
1067
        {
1068
            p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1069
            for(i=0;i<recv_nindex*datasize;i++)
1070
            {
1071
                p[i] += overlap->recvbuf[i];
1072
            }
1073
        }
1074
    }
1075
}
1076
#endif
1077

    
1078

    
1079
static int
1080
copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1081
{
1082
    ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1083
    ivec    local_pme_size;
1084
    int     i,ix,iy,iz;
1085
    int     pmeidx,fftidx;
1086

    
1087
    /* Dimensions should be identical for A/B grid, so we just use A here */
1088
    gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1089
                                   local_fft_ndata,
1090
                                   local_fft_offset,
1091
                                   local_fft_size);
1092

    
1093
    local_pme_size[0] = pme->pmegrid_nx;
1094
    local_pme_size[1] = pme->pmegrid_ny;
1095
    local_pme_size[2] = pme->pmegrid_nz;
1096

    
1097
    /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1098
     * the offset is identical, and the PME grid always has more data (due to overlap)
1099
     */
1100
    {
1101
#ifdef DEBUG_PME
1102
        FILE *fp,*fp2;
1103
        char fn[STRLEN],format[STRLEN];
1104
        real val;
1105
        sprintf(fn,"pmegrid%d.pdb",pme->nodeid);
1106
        fp = ffopen(fn,"w");
1107
        sprintf(fn,"pmegrid%d.txt",pme->nodeid);
1108
        fp2 = ffopen(fn,"w");
1109
     sprintf(format,"%s%s\n",pdbformat,"%6.2f%6.2f");
1110
#endif
1111

    
1112
    for(ix=0;ix<local_fft_ndata[XX];ix++)
1113
    {
1114
        for(iy=0;iy<local_fft_ndata[YY];iy++)
1115
        {
1116
            for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1117
            {
1118
                pmeidx = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1119
                fftidx = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1120
                fftgrid[fftidx] = pmegrid[pmeidx];
1121
#ifdef DEBUG_PME
1122
                val = 100*pmegrid[pmeidx];
1123
                if (pmegrid[pmeidx] != 0)
1124
                fprintf(fp,format,"ATOM",pmeidx,"CA","GLY",' ',pmeidx,' ',
1125
                        5.0*ix,5.0*iy,5.0*iz,1.0,val);
1126
                if (pmegrid[pmeidx] != 0)
1127
                    fprintf(fp2,"%-12s  %5d  %5d  %5d  %12.5e\n",
1128
                            "qgrid",
1129
                            pme->pmegrid_start_ix + ix,
1130
                            pme->pmegrid_start_iy + iy,
1131
                            pme->pmegrid_start_iz + iz,
1132
                            pmegrid[pmeidx]);
1133
#endif
1134
            }
1135
        }
1136
    }
1137
#ifdef DEBUG_PME
1138
    ffclose(fp);
1139
    ffclose(fp2);
1140
#endif
1141
    }
1142
    return 0;
1143
}
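/* Index arithmetic used above (illustration only): both grids are stored
 * row-major with z fastest, so element (ix,iy,iz) lives at
 *   (ix*size_y + iy)*size_z + iz.
 * E.g. with a local PME size of 20 x 20 x 24, point (1,2,3) maps to
 * (1*20 + 2)*24 + 3 = 531, while the same point in an FFT grid of size
 * 20 x 20 x 26 maps to (1*20 + 2)*26 + 3 = 575; hence the separate
 * pmeidx and fftidx.
 */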
1144

    
1145

    
1146
static gmx_cycles_t omp_cyc_start()
1147
{
1148
    return gmx_cycles_read();
1149
}
1150

    
1151
static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1152
{
1153
    return gmx_cycles_read() - c;
1154
}
1155

    
1156

    
1157
static int
1158
copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1159
                        int nthread,int thread)
1160
{
1161
    ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1162
    ivec    local_pme_size;
1163
    int     ixy0,ixy1,ixy,ix,iy,iz;
1164
    int     pmeidx,fftidx;
1165
#ifdef PME_TIME_THREADS
1166
    gmx_cycles_t c1;
1167
    static double cs1=0;
1168
    static int cnt=0;
1169
#endif
1170

    
1171
#ifdef PME_TIME_THREADS
1172
    c1 = omp_cyc_start();
1173
#endif
1174
    /* Dimensions should be identical for A/B grid, so we just use A here */
1175
    gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1176
                                   local_fft_ndata,
1177
                                   local_fft_offset,
1178
                                   local_fft_size);
1179

    
1180
    local_pme_size[0] = pme->pmegrid_nx;
1181
    local_pme_size[1] = pme->pmegrid_ny;
1182
    local_pme_size[2] = pme->pmegrid_nz;
1183

    
1184
    /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1185
     * the offset is identical, and the PME grid always has more data (due to overlap)
1186
     */
1187
    ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1188
    ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1189

    
1190
    for(ixy=ixy0;ixy<ixy1;ixy++)
1191
    {
1192
        ix = ixy/local_fft_ndata[YY];
1193
        iy = ixy - ix*local_fft_ndata[YY];
1194

    
1195
        pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1196
        fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1197
        for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1198
        {
1199
            pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1200
        }
1201
    }
1202

    
1203
#ifdef PME_TIME_THREADS
1204
    c1 = omp_cyc_end(c1);
1205
    cs1 += (double)c1;
1206
    cnt++;
1207
    if (cnt % 20 == 0)
1208
    {
1209
        printf("copy %.2f\n",cs1*1e-9);
1210
    }
1211
#endif
1212

    
1213
    return 0;
1214
}
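/* Worked example for the thread decomposition above (illustration only):
 * ixy runs over all (x,y) grid lines. With
 * local_fft_ndata[XX]*local_fft_ndata[YY] = 12 and nthread = 5, the ranges
 * [ixy0,ixy1) become [0,2), [2,4), [4,7), [7,9), [9,12): the integer
 * arithmetic (thread*total)/nthread splits the work as evenly as possible
 * without any explicit remainder handling.
 */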
1215

    
1216

    
1217
static void
1218
wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1219
{
1220
    int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix,iy,iz;
1221

    
1222
    nx = pme->nkx;
1223
    ny = pme->nky;
1224
    nz = pme->nkz;
1225

    
1226
    pnx = pme->pmegrid_nx;
1227
    pny = pme->pmegrid_ny;
1228
    pnz = pme->pmegrid_nz;
1229

    
1230
    overlap = pme->pme_order - 1;
1231

    
1232
    /* Add periodic overlap in z */
1233
    for(ix=0; ix<pme->pmegrid_nx; ix++)
1234
    {
1235
        for(iy=0; iy<pme->pmegrid_ny; iy++)
1236
        {
1237
            for(iz=0; iz<overlap; iz++)
1238
            {
1239
                pmegrid[(ix*pny+iy)*pnz+iz] +=
1240
                    pmegrid[(ix*pny+iy)*pnz+nz+iz];
1241
            }
1242
        }
1243
    }
1244

    
1245
    if (pme->nnodes_minor == 1)
1246
    {
1247
       for(ix=0; ix<pme->pmegrid_nx; ix++)
1248
       {
1249
           for(iy=0; iy<overlap; iy++)
1250
           {
1251
               for(iz=0; iz<nz; iz++)
1252
               {
1253
                   pmegrid[(ix*pny+iy)*pnz+iz] +=
1254
                       pmegrid[(ix*pny+ny+iy)*pnz+iz];
1255
               }
1256
           }
1257
       }
1258
    }
1259

    
1260
    if (pme->nnodes_major == 1)
1261
    {
1262
        ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1263

    
1264
        for(ix=0; ix<overlap; ix++)
1265
        {
1266
            for(iy=0; iy<ny_x; iy++)
1267
            {
1268
                for(iz=0; iz<nz; iz++)
1269
                {
1270
                    pmegrid[(ix*pny+iy)*pnz+iz] +=
1271
                        pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1272
                }
1273
            }
1274
        }
1275
    }
1276
}
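/* Worked example for the periodic folding above (illustration only):
 * with pme_order = 4 the overlap is 3. For nkz = 20 this adds grid planes
 * z = 20, 21, 22 onto z = 0, 1, 2, so after wrapping all charge spread
 * beyond the periodic boundary has been folded back into the 0..nkz-1
 * range that is handed to the FFT.
 */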
1277

    
1278

    
1279
static void
1280
unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1281
{
1282
    int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix;
1283

    
1284
    nx = pme->nkx;
1285
    ny = pme->nky;
1286
    nz = pme->nkz;
1287

    
1288
    pnx = pme->pmegrid_nx;
1289
    pny = pme->pmegrid_ny;
1290
    pnz = pme->pmegrid_nz;
1291

    
1292
    overlap = pme->pme_order - 1;
1293

    
1294
    if (pme->nnodes_major == 1)
1295
    {
1296
        ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1297

    
1298
        for(ix=0; ix<overlap; ix++)
1299
        {
1300
            int iy,iz;
1301

    
1302
            for(iy=0; iy<ny_x; iy++)
1303
            {
1304
                for(iz=0; iz<nz; iz++)
1305
                {
1306
                    pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1307
                        pmegrid[(ix*pny+iy)*pnz+iz];
1308
                }
1309
            }
1310
        }
1311
    }
1312

    
1313
    if (pme->nnodes_minor == 1)
1314
    {
1315
#pragma omp parallel for num_threads(pme->nthread) schedule(static)
1316
       for(ix=0; ix<pme->pmegrid_nx; ix++)
1317
       {
1318
           int iy,iz;
1319

    
1320
           for(iy=0; iy<overlap; iy++)
1321
           {
1322
               for(iz=0; iz<nz; iz++)
1323
               {
1324
                   pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1325
                       pmegrid[(ix*pny+iy)*pnz+iz];
1326
               }
1327
           }
1328
       }
1329
    }
1330

    
1331
    /* Copy periodic overlap in z */
1332
#pragma omp parallel for num_threads(pme->nthread) schedule(static)
1333
    for(ix=0; ix<pme->pmegrid_nx; ix++)
1334
    {
1335
        int iy,iz;
1336

    
1337
        for(iy=0; iy<pme->pmegrid_ny; iy++)
1338
        {
1339
            for(iz=0; iz<overlap; iz++)
1340
            {
1341
                pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1342
                    pmegrid[(ix*pny+iy)*pnz+iz];
1343
            }
1344
        }
1345
    }
1346
}
1347

    
1348
static void clear_grid(int nx,int ny,int nz,real *grid,
1349
                       ivec fs,int *flag,
1350
                       int fx,int fy,int fz,
1351
                       int order)
1352
{
1353
    int nc,ncz;
1354
    int fsx,fsy,fsz,gx,gy,gz,g0x,g0y,x,y,z;
1355
    int flind;
1356

    
1357
    nc  = 2 + (order - 2)/FLBS;
1358
    ncz = 2 + (order - 2)/FLBSZ;
1359

    
1360
    for(fsx=fx; fsx<fx+nc; fsx++)
1361
    {
1362
        for(fsy=fy; fsy<fy+nc; fsy++)
1363
        {
1364
            for(fsz=fz; fsz<fz+ncz; fsz++)
1365
            {
1366
                flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1367
                if (flag[flind] == 0)
1368
                {
1369
                    gx = fsx*FLBS;
1370
                    gy = fsy*FLBS;
1371
                    gz = fsz*FLBSZ;
1372
                    g0x = (gx*ny + gy)*nz + gz;
1373
                    for(x=0; x<FLBS; x++)
1374
                    {
1375
                        g0y = g0x;
1376
                        for(y=0; y<FLBS; y++)
1377
                        {
1378
                            for(z=0; z<FLBSZ; z++)
1379
                            {
1380
                                grid[g0y+z] = 0;
1381
                            }
1382
                            g0y += nz;
1383
                        }
1384
                        g0x += ny*nz;
1385
                    }
1386

    
1387
                    flag[flind] = 1;
1388
                }
1389
            }
1390
        }
1391
    }
1392
}
1393

    
1394
/* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1395
#define DO_BSPLINE(order)                            \
1396
for(ithx=0; (ithx<order); ithx++)                    \
1397
{                                                    \
1398
    index_x = (i0+ithx)*pny*pnz;                     \
1399
    valx    = qn*thx[ithx];                          \
1400
                                                     \
1401
    for(ithy=0; (ithy<order); ithy++)                \
1402
    {                                                \
1403
        valxy    = valx*thy[ithy];                   \
1404
        index_xy = index_x+(j0+ithy)*pnz;            \
1405
                                                     \
1406
        for(ithz=0; (ithz<order); ithz++)            \
1407
        {                                            \
1408
            index_xyz        = index_xy+(k0+ithz);   \
1409
            grid[index_xyz] += valxy*thz[ithz];      \
1410
        }                                            \
1411
    }                                                \
1412
}
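/* Illustration of the macro above: for pme_order = 4 each charge updates
 * 4*4*4 = 64 grid points, for order 5 it is 125. The flattened grid index
 * is index_xyz = ((i0+ithx)*pny + (j0+ithy))*pnz + (k0+ithz), the same
 * z-fastest layout used elsewhere in this file.
 */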
1413

    
1414

    
1415
static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1416
                                     pme_atomcomm_t *atc, splinedata_t *spline,
1417
                                     pme_spline_work_t *work)
1418
{
1419

    
1420
    /* spread charges from home atoms to local grid */
1421
    real     *grid;
1422
    pme_overlap_t *ol;
1423
    int      b,i,nn,n,ithx,ithy,ithz,i0,j0,k0;
1424
    int *    idxptr;
1425
    int      order,norder,index_x,index_xy,index_xyz;
1426
    real     valx,valxy,qn;
1427
    real     *thx,*thy,*thz;
1428
    int      localsize, bndsize;
1429
    int      pnx,pny,pnz,ndatatot;
1430
    int      offx,offy,offz;
1431

    
1432
    pnx = pmegrid->n[XX];
1433
    pny = pmegrid->n[YY];
1434
    pnz = pmegrid->n[ZZ];
1435

    
1436
    offx = pmegrid->offset[XX];
1437
    offy = pmegrid->offset[YY];
1438
    offz = pmegrid->offset[ZZ];
1439

    
1440
    ndatatot = pnx*pny*pnz;
1441
    grid = pmegrid->grid;
1442
    for(i=0;i<ndatatot;i++)
1443
    {
1444
        grid[i] = 0;
1445
    }
1446

    
1447
    order = pmegrid->order;
1448

    
1449
    for(nn=0; nn<spline->n; nn++)
1450
    {
1451
        n  = spline->ind[nn];
1452
        qn = atc->q[n];
1453

    
1454
        if (qn != 0)
1455
        {
1456
            idxptr = atc->idx[n];
1457
            norder = nn*order;
1458

    
1459
            i0   = idxptr[XX] - offx;
1460
            j0   = idxptr[YY] - offy;
1461
            k0   = idxptr[ZZ] - offz;
1462

    
1463
            thx = spline->theta[XX] + norder;
1464
            thy = spline->theta[YY] + norder;
1465
            thz = spline->theta[ZZ] + norder;
1466

    
1467
            switch (order) {
1468
            case 4:
1469
#ifdef PME_SSE
1470
#ifdef PME_SSE_UNALIGNED
1471
#define PME_SPREAD_SSE_ORDER4
1472
#else
1473
#define PME_SPREAD_SSE_ALIGNED
1474
#define PME_ORDER 4
1475
#endif
1476
#include "pme_sse_single.h"
1477
#else
1478
                DO_BSPLINE(4);
1479
#endif
1480
                break;
1481
            case 5:
1482
#ifdef PME_SSE
1483
#define PME_SPREAD_SSE_ALIGNED
1484
#define PME_ORDER 5
1485
#include "pme_sse_single.h"
1486
#else
1487
                DO_BSPLINE(5);
1488
#endif
1489
                break;
1490
            default:
1491
                DO_BSPLINE(order);
1492
                break;
1493
            }
1494
        }
1495
    }
1496
}
1497

    
1498
static void set_grid_alignment(int *pmegrid_nz,int pme_order)
1499
{
1500
#ifdef PME_SSE
1501
    if (pme_order == 5
1502
#ifndef PME_SSE_UNALIGNED
1503
        || pme_order == 4
1504
#endif
1505
        )
1506
    {
1507
        /* Round nz up to a multiple of 4 to ensure alignment */
1508
        *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1509
    }
1510
#endif
1511
}
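/* Worked example for the rounding above (illustration only):
 * (*pmegrid_nz + 3) & ~3 rounds up to the next multiple of 4,
 * e.g. 17 -> 20, 18 -> 20, 20 -> 20, so that 16-byte (4 float) aligned
 * SSE loads and stores along z stay within the padded grid.
 */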
1512

    
1513
static void set_gridsize_alignment(int *gridsize,int pme_order)
1514
{
1515
#ifdef PME_SSE
1516
#ifndef PME_SSE_UNALIGNED
1517
    if (pme_order == 4)
1518
    {
1519
        /* Add extra elements to ensured aligned operations do not go
1520
         * beyond the allocated grid size.
1521
         * Note that for pme_order=5, the pme grid z-size alignment
1522
         * ensures that we will not go beyond the grid size.
1523
         */
1524
         *gridsize += 4;
1525
    }
1526
#endif
1527
#endif
1528
}
1529

    
1530
static void pmegrid_init(pmegrid_t *grid,
1531
                         int cx, int cy, int cz,
1532
                         int x0, int y0, int z0,
1533
                         int x1, int y1, int z1,
1534
                         gmx_bool set_alignment,
1535
                         int pme_order,
1536
                         real *ptr)
1537
{
1538
    int nz,gridsize;
1539

    
1540
    grid->ci[XX] = cx;
1541
    grid->ci[YY] = cy;
1542
    grid->ci[ZZ] = cz;
1543
    grid->offset[XX] = x0;
1544
    grid->offset[YY] = y0;
1545
    grid->offset[ZZ] = z0;
1546
    grid->n[XX]      = x1 - x0 + pme_order - 1;
1547
    grid->n[YY]      = y1 - y0 + pme_order - 1;
1548
    grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1549

    
1550
    nz = grid->n[ZZ];
1551
    set_grid_alignment(&nz,pme_order);
1552
    if (set_alignment)
1553
    {
1554
        grid->n[ZZ] = nz;
1555
    }
1556
    else if (nz != grid->n[ZZ])
1557
    {
1558
        gmx_incons("pmegrid_init call with an unaligned z size");
1559
    }
1560

    
1561
    grid->order = pme_order;
1562
    if (ptr == NULL)
1563
    {
1564
        gridsize = grid->n[XX]*grid->n[YY]*grid->n[ZZ];
1565
        set_gridsize_alignment(&gridsize,pme_order);
1566
        snew_aligned(grid->grid,gridsize,16);
1567
    }
1568
    else
1569
    {
1570
        grid->grid = ptr;
1571
    }
1572
}
1573

    
1574
static int div_round_up(int numerator,int denominator)
1575
{
1576
    return (numerator + denominator - 1)/denominator;
1577
}
1578

    
1579
static void make_subgrid_division(const ivec n,int ovl,int nthread,
1580
                                  ivec nsub)
1581
{
1582
    int gsize_opt,gsize;
1583
    int nsx,nsy,nsz;
1584
    char *env;
1585

    
1586
    gsize_opt = -1;
1587
    for(nsx=1; nsx<=nthread; nsx++)
1588
    {
1589
        if (nthread % nsx == 0)
1590
        {
1591
            for(nsy=1; nsy<=nthread; nsy++)
1592
            {
1593
                if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1594
                {
1595
                    nsz = nthread/(nsx*nsy);
1596

    
1597
                    /* Determine the number of grid points per thread */
1598
                    gsize =
1599
                        (div_round_up(n[XX],nsx) + ovl)*
1600
                        (div_round_up(n[YY],nsy) + ovl)*
1601
                        (div_round_up(n[ZZ],nsz) + ovl);
1602

    
1603
                    /* Minimize the number of grid points per thread
1604
                     * and, secondarily, the number of cuts in minor dimensions.
1605
                     */
1606
                    if (gsize_opt == -1 ||
1607
                        gsize < gsize_opt ||
1608
                        (gsize == gsize_opt &&
1609
                         (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1610
                    {
1611
                        nsub[XX] = nsx;
1612
                        nsub[YY] = nsy;
1613
                        nsub[ZZ] = nsz;
1614
                        gsize_opt = gsize;
1615
                    }
1616
                }
1617
            }
1618
        }
1619
    }
1620

    
1621
    env = getenv("GMX_PME_THREAD_DIVISION");
1622
    if (env != NULL)
1623
    {
1624
        sscanf(env,"%d %d %d",&nsub[XX],&nsub[YY],&nsub[ZZ]);
1625
    }
1626

    
1627
    if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1628
    {
1629
        gmx_fatal(FARGS,"PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)",nsub[XX],nsub[YY],nsub[ZZ],nthread);
1630
    }
1631
}
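/* Worked example for the subgrid search above (illustration only):
 * for a 32x32x32 base grid, pme_order = 4 (ovl = 3) and nthread = 8, the
 * candidate factorizations include 8x1x1, 2x2x2, 1x1x8, ... For 2x2x2 each
 * thread gets (16+3)*(16+3)*(16+3) = 6859 points, while 8x1x1 gives
 * (4+3)*(32+3)*(32+3) = 8575, so the more cubic division wins.
 */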
1632

    
1633
static void pmegrids_init(pmegrids_t *grids,
1634
                          int nx,int ny,int nz,int nz_base,
1635
                          int pme_order,
1636
                          int nthread,
1637
                          int overlap_x,
1638
                          int overlap_y)
1639
{
1640
    ivec n,n_base,g0,g1;
1641
    int t,x,y,z,d,i,tfac;
1642
    int max_comm_lines;
1643

    
1644
    n[XX] = nx - (pme_order - 1);
1645
    n[YY] = ny - (pme_order - 1);
1646
    n[ZZ] = nz - (pme_order - 1);
1647

    
1648
    copy_ivec(n,n_base);
1649
    n_base[ZZ] = nz_base;
1650

    
1651
    pmegrid_init(&grids->grid,0,0,0,0,0,0,n[XX],n[YY],n[ZZ],FALSE,pme_order,
1652
                 NULL);
1653

    
1654
    grids->nthread = nthread;
1655

    
1656
    make_subgrid_division(n_base,pme_order-1,grids->nthread,grids->nc);
1657

    
1658
    if (grids->nthread > 1)
1659
    {
1660
        ivec nst;
1661
        int gridsize;
1662
        real *grid_all;
1663

    
1664
        for(d=0; d<DIM; d++)
1665
        {
1666
            nst[d] = div_round_up(n[d],grids->nc[d]) + pme_order - 1;
1667
        }
1668
        set_grid_alignment(&nst[ZZ],pme_order);
1669

    
1670
        if (debug)
1671
        {
1672
            fprintf(debug,"pmegrid thread local division: %d x %d x %d\n",
1673
                    grids->nc[XX],grids->nc[YY],grids->nc[ZZ]);
1674
            fprintf(debug,"pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1675
                    nx,ny,nz,
1676
                    nst[XX],nst[YY],nst[ZZ]);
1677
        }
1678

    
1679
        snew(grids->grid_th,grids->nthread);
1680
        t = 0;
1681
        gridsize = nst[XX]*nst[YY]*nst[ZZ];
1682
        set_gridsize_alignment(&gridsize,pme_order);
1683
        snew_aligned(grid_all,
1684
                     grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1685
                     16);
1686

    
1687
        for(x=0; x<grids->nc[XX]; x++)
1688
        {
1689
            for(y=0; y<grids->nc[YY]; y++)
1690
            {
1691
                for(z=0; z<grids->nc[ZZ]; z++)
1692
                {
1693
                    pmegrid_init(&grids->grid_th[t],
1694
                                 x,y,z,
1695
                                 (n[XX]*(x  ))/grids->nc[XX],
1696
                                 (n[YY]*(y  ))/grids->nc[YY],
1697
                                 (n[ZZ]*(z  ))/grids->nc[ZZ],
1698
                                 (n[XX]*(x+1))/grids->nc[XX],
1699
                                 (n[YY]*(y+1))/grids->nc[YY],
1700
                                 (n[ZZ]*(z+1))/grids->nc[ZZ],
1701
                                 TRUE,
1702
                                 pme_order,
1703
                                 grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1704
                    t++;
1705
                }
1706
            }
1707
        }
1708
    }
1709

    
1710
    snew(grids->g2t,DIM);
1711
    tfac = 1;
1712
    for(d=DIM-1; d>=0; d--)
1713
    {
1714
        snew(grids->g2t[d],n[d]);
1715
        t = 0;
1716
        for(i=0; i<n[d]; i++)
1717
        {
1718
            /* The second check should match the parameters
1719
             * of the pmegrid_init call above.
1720
             */
1721
            while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1722
            {
1723
                t++;
1724
            }
1725
            grids->g2t[d][i] = t*tfac;
1726
        }
1727

    
1728
        tfac *= grids->nc[d];
1729

    
1730
        switch (d)
1731
        {
1732
        case XX: max_comm_lines = overlap_x;     break;
1733
        case YY: max_comm_lines = overlap_y;     break;
1734
        case ZZ: max_comm_lines = pme_order - 1; break;
1735
        }
1736
        grids->nthread_comm[d] = 0;
1737
        while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines)
1738
        {
1739
            grids->nthread_comm[d]++;
1740
        }
1741
        if (debug != NULL)
1742
        {
1743
            fprintf(debug,"pmegrid thread grid communication range in %c: %d\n",
1744
                    'x'+d,grids->nthread_comm[d]);
1745
        }
1746
        /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1747
         * work, but this is not a problematic restriction.
1748
         */
1749
        if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1750
        {
1751
            gmx_fatal(FARGS,"Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME",grids->nthread);
1752
        }
1753
    }
1754
}
1755

    
1756

    
1757
static void pmegrids_destroy(pmegrids_t *grids)
1758
{
1759
    int t;
1760

    
1761
    if (grids->grid.grid != NULL)
1762
    {
1763
        sfree(grids->grid.grid);
1764

    
1765
        if (grids->nthread > 0)
1766
        {
1767
            for(t=0; t<grids->nthread; t++)
1768
            {
1769
                sfree(grids->grid_th[t].grid);
1770
            }
1771
            sfree(grids->grid_th);
1772
        }
1773
    }
1774
}
1775

    
1776

    
1777
static void realloc_work(pme_work_t *work,int nkx)
1778
{
1779
    if (nkx > work->nalloc)
1780
    {
1781
        work->nalloc = nkx;
1782
        srenew(work->mhx  ,work->nalloc);
1783
        srenew(work->mhy  ,work->nalloc);
1784
        srenew(work->mhz  ,work->nalloc);
1785
        srenew(work->m2   ,work->nalloc);
1786
        /* Allocate an aligned pointer for SSE operations, including 3 extra
1787
         * elements at the end since SSE operates on 4 elements at a time.
1788
         */
1789
        sfree_aligned(work->denom);
1790
        sfree_aligned(work->tmp1);
1791
        sfree_aligned(work->eterm);
1792
        snew_aligned(work->denom,work->nalloc+3,16);
1793
        snew_aligned(work->tmp1 ,work->nalloc+3,16);
1794
        snew_aligned(work->eterm,work->nalloc+3,16);
1795
        srenew(work->m2inv,work->nalloc);
1796
    }
1797
}
1798

    
1799

    
1800
static void free_work(pme_work_t *work)
1801
{
1802
    sfree(work->mhx);
1803
    sfree(work->mhy);
1804
    sfree(work->mhz);
1805
    sfree(work->m2);
1806
    sfree_aligned(work->denom);
1807
    sfree_aligned(work->tmp1);
1808
    sfree_aligned(work->eterm);
1809
    sfree(work->m2inv);
1810
}
1811

    
1812

    
1813
#ifdef PME_SSE
1814
    /* Calculate exponentials through SSE in float precision */
1815
inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1816
{
1817
    {
1818
        const __m128 two = _mm_set_ps(2.0f,2.0f,2.0f,2.0f);
1819
        __m128 f_sse;
1820
        __m128 lu;
1821
        __m128 tmp_d1,d_inv,tmp_r,tmp_e;
1822
        int kx;
1823
        f_sse = _mm_load1_ps(&f);
1824
        for(kx=0; kx<end; kx+=4)
1825
        {
1826
            tmp_d1   = _mm_load_ps(d_aligned+kx);
1827
            lu       = _mm_rcp_ps(tmp_d1);
1828
            d_inv    = _mm_mul_ps(lu,_mm_sub_ps(two,_mm_mul_ps(lu,tmp_d1)));
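            /* Note: the two statements above compute 1.0/denom in single
             * precision: _mm_rcp_ps gives a ~12-bit reciprocal estimate lu,
             * and one Newton-Raphson step, lu*(2 - lu*d), roughly doubles
             * the number of correct bits.
             */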
1829
            tmp_r    = _mm_load_ps(r_aligned+kx);
1830
            tmp_r    = gmx_mm_exp_ps(tmp_r);
1831
            tmp_e    = _mm_mul_ps(f_sse,d_inv);
1832
            tmp_e    = _mm_mul_ps(tmp_e,tmp_r);
1833
            _mm_store_ps(e_aligned+kx,tmp_e);
1834
        }
1835
    }
1836
}
1837
#else
1838
inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1839
{
1840
    int kx;
1841
    for(kx=start; kx<end; kx++)
1842
    {
1843
        d[kx] = 1.0/d[kx];
1844
    }
1845
    for(kx=start; kx<end; kx++)
1846
    {
1847
        r[kx] = exp(r[kx]);
1848
    }
1849
    for(kx=start; kx<end; kx++)
1850
    {
1851
        e[kx] = f*r[kx]*d[kx];
1852
    }
1853
}
1854
#endif
1855

    
1856

    
1857
static int solve_pme_yzx(gmx_pme_t pme,t_complex *grid,
1858
                         real ewaldcoeff,real vol,
1859
                         gmx_bool bEnerVir,
1860
                         int nthread,int thread)
1861
{
1862
    /* do recip sum over local cells in grid */
1863
    /* y major, z middle, x minor or continuous */
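    /* Descriptive note: the loops below evaluate the usual smooth-PME
     * reciprocal-space sum. For every k-vector m, denom[] collects
     * pi*V*|m|^2 times the product of the B-spline moduli, tmp1[] collects
     * -pi^2*|m|^2/ewaldcoeff^2, and eterm[] = elfac*exp(tmp1)/denom is the
     * prefactor applied to the transformed charge grid. With bEnerVir the
     * same loop accumulates 0.5*sum(corner_fac*eterm*2*|S(m)|^2) as the
     * energy; the factor 2 and the 0.5 corner factors compensate for
     * storing only half of the z modes of the real-to-complex FFT.
     */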
1864
    t_complex *p0;
1865
    int     kx,ky,kz,maxkx,maxky,maxkz;
1866
    int     nx,ny,nz,iyz0,iyz1,iyz,iy,iz,kxstart,kxend;
1867
    real    mx,my,mz;
1868
    real    factor=M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1869
    real    ets2,struct2,vfactor,ets2vf;
1870
    real    d1,d2,energy=0;
1871
    real    by,bz;
1872
    real    virxx=0,virxy=0,virxz=0,viryy=0,viryz=0,virzz=0;
1873
    real    rxx,ryx,ryy,rzx,rzy,rzz;
1874
    pme_work_t *work;
1875
    real    *mhx,*mhy,*mhz,*m2,*denom,*tmp1,*eterm,*m2inv;
1876
    real    mhxk,mhyk,mhzk,m2k;
1877
    real    corner_fac;
1878
    ivec    complex_order;
1879
    ivec    local_ndata,local_offset,local_size;
1880
    real    elfac;
1881

    
1882
    elfac = ONE_4PI_EPS0/pme->epsilon_r;
1883

    
1884
    nx = pme->nkx;
1885
    ny = pme->nky;
1886
    nz = pme->nkz;
1887

    
1888
    /* Dimensions should be identical for A/B grid, so we just use A here */
1889
    gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1890
                                      complex_order,
1891
                                      local_ndata,
1892
                                      local_offset,
1893
                                      local_size);
1894

    
1895
    rxx = pme->recipbox[XX][XX];
1896
    ryx = pme->recipbox[YY][XX];
1897
    ryy = pme->recipbox[YY][YY];
1898
    rzx = pme->recipbox[ZZ][XX];
1899
    rzy = pme->recipbox[ZZ][YY];
1900
    rzz = pme->recipbox[ZZ][ZZ];
1901

    
1902
    maxkx = (nx+1)/2;
1903
    maxky = (ny+1)/2;
1904
    maxkz = nz/2+1;
1905

    
1906
    work = &pme->work[thread];
1907
    mhx   = work->mhx;
1908
    mhy   = work->mhy;
1909
    mhz   = work->mhz;
1910
    m2    = work->m2;
1911
    denom = work->denom;
1912
    tmp1  = work->tmp1;
1913
    eterm = work->eterm;
1914
    m2inv = work->m2inv;
1915

    
1916
    iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1917
    iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1918

    
1919
    for(iyz=iyz0; iyz<iyz1; iyz++)
1920
    {
1921
        iy = iyz/local_ndata[ZZ];
1922
        iz = iyz - iy*local_ndata[ZZ];
1923

    
1924
        ky = iy + local_offset[YY];
1925

    
1926
        if (ky < maxky)
1927
        {
1928
            my = ky;
1929
        }
1930
        else
1931
        {
1932
            my = (ky - ny);
1933
        }
1934

    
1935
        by = M_PI*vol*pme->bsp_mod[YY][ky];
1936

    
1937
        kz = iz + local_offset[ZZ];
1938

    
1939
        mz = kz;
1940

    
1941
        bz = pme->bsp_mod[ZZ][kz];
1942

    
1943
        /* 0.5 correction for corner points */
1944
        corner_fac = 1;
1945
        if (kz == 0 || kz == (nz+1)/2)
1946
        {
1947
            corner_fac = 0.5;
1948
        }
1949

    
1950
        p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1951

    
1952
        /* We should skip the k-space point (0,0,0) */
1953
        if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1954
        {
1955
            kxstart = local_offset[XX];
1956
        }
1957
        else
1958
        {
1959
            kxstart = local_offset[XX] + 1;
1960
            p0++;
1961
        }
1962
        kxend = local_offset[XX] + local_ndata[XX];
1963

    
1964
        if (bEnerVir)
1965
        {
1966
            /* More expensive inner loop, especially because of the storage
1967
             * of the mh elements in arrays.
1968
             * Because x is the minor grid index, all mh elements
1969
             * depend on kx for triclinic unit cells.
1970
             */
1971

    
1972
            /* Two explicit loops to avoid a conditional inside the loop */
1973
            for(kx=kxstart; kx<maxkx; kx++)
1974
            {
1975
                mx = kx;
1976

    
1977
                mhxk      = mx * rxx;
1978
                mhyk      = mx * ryx + my * ryy;
1979
                mhzk      = mx * rzx + my * rzy + mz * rzz;
1980
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1981
                mhx[kx]   = mhxk;
1982
                mhy[kx]   = mhyk;
1983
                mhz[kx]   = mhzk;
1984
                m2[kx]    = m2k;
1985
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
1986
                tmp1[kx]  = -factor*m2k;
1987
            }
1988

    
1989
            for(kx=maxkx; kx<kxend; kx++)
1990
            {
1991
                mx = (kx - nx);
1992

    
1993
                mhxk      = mx * rxx;
1994
                mhyk      = mx * ryx + my * ryy;
1995
                mhzk      = mx * rzx + my * rzy + mz * rzz;
1996
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1997
                mhx[kx]   = mhxk;
1998
                mhy[kx]   = mhyk;
1999
                mhz[kx]   = mhzk;
2000
                m2[kx]    = m2k;
2001
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2002
                tmp1[kx]  = -factor*m2k;
2003
            }
2004

    
2005
            for(kx=kxstart; kx<kxend; kx++)
2006
            {
2007
                m2inv[kx] = 1.0/m2[kx];
2008
            }
2009

    
2010
            calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2011

    
2012
            for(kx=kxstart; kx<kxend; kx++,p0++)
2013
            {
2014
                d1      = p0->re;
2015
                d2      = p0->im;
2016

    
2017
                p0->re  = d1*eterm[kx];
2018
                p0->im  = d2*eterm[kx];
2019

    
2020
                struct2 = 2.0*(d1*d1+d2*d2);
2021

    
2022
                tmp1[kx] = eterm[kx]*struct2;
2023
            }
2024

    
2025
            for(kx=kxstart; kx<kxend; kx++)
2026
            {
2027
                ets2     = corner_fac*tmp1[kx];
2028
                vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2029
                energy  += ets2;
2030

    
2031
                ets2vf   = ets2*vfactor;
2032
                virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2033
                virxy   += ets2vf*mhx[kx]*mhy[kx];
2034
                virxz   += ets2vf*mhx[kx]*mhz[kx];
2035
                viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2036
                viryz   += ets2vf*mhy[kx]*mhz[kx];
2037
                virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2038
            }
2039
        }
2040
        else
2041
        {
2042
            /* We don't need to calculate the energy and the virial.
2043
             * In this case the triclinic overhead is small.
2044
             */
2045

    
2046
            /* Two explicit loops to avoid a conditional inside the loop */
2047

    
2048
            for(kx=kxstart; kx<maxkx; kx++)
2049
            {
2050
                mx = kx;
2051

    
2052
                mhxk      = mx * rxx;
2053
                mhyk      = mx * ryx + my * ryy;
2054
                mhzk      = mx * rzx + my * rzy + mz * rzz;
2055
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2056
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2057
                tmp1[kx]  = -factor*m2k;
2058
            }
2059

    
2060
            for(kx=maxkx; kx<kxend; kx++)
2061
            {
2062
                mx = (kx - nx);
2063

    
2064
                mhxk      = mx * rxx;
2065
                mhyk      = mx * ryx + my * ryy;
2066
                mhzk      = mx * rzx + my * rzy + mz * rzz;
2067
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2068
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2069
                tmp1[kx]  = -factor*m2k;
2070
            }
2071

    
2072
            calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2073

    
2074
            for(kx=kxstart; kx<kxend; kx++,p0++)
2075
            {
2076
                d1      = p0->re;
2077
                d2      = p0->im;
2078

    
2079
                p0->re  = d1*eterm[kx];
2080
                p0->im  = d2*eterm[kx];
2081
            }
2082
        }
2083
    }
2084

    
2085
    if (bEnerVir)
2086
    {
2087
        /* Update virial with local values.
2088
         * The virial is symmetric by definition.
2089
         * this virial seems ok for isotropic scaling, but I'm
2090
         * experiencing problems on semiisotropic membranes.
2091
         * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2092
         */
2093
        work->vir[XX][XX] = 0.25*virxx;
2094
        work->vir[YY][YY] = 0.25*viryy;
2095
        work->vir[ZZ][ZZ] = 0.25*virzz;
2096
        work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2097
        work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2098
        work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2099

    
2100
        /* This energy should be corrected for a charged system */
2101
        work->energy = 0.5*energy;
2102
    }
2103

    
2104
    /* Return the loop count */
2105
    return local_ndata[YY]*local_ndata[XX];
2106
}
2107

    
2108
static void get_pme_ener_vir(const gmx_pme_t pme,int nthread,
2109
                             real *mesh_energy,matrix vir)
2110
{
2111
    /* This function sums output over threads
2112
     * and should therefore only be called after thread synchronization.
2113
     */
2114
    int thread;
2115

    
2116
    *mesh_energy = pme->work[0].energy;
2117
    copy_mat(pme->work[0].vir,vir);
2118

    
2119
    for(thread=1; thread<nthread; thread++)
2120
    {
2121
        *mesh_energy += pme->work[thread].energy;
2122
        m_add(vir,pme->work[thread].vir,vir);
2123
    }
2124
}
2125

    
2126
#define DO_FSPLINE(order)                      \
2127
for(ithx=0; (ithx<order); ithx++)              \
2128
{                                              \
2129
    index_x = (i0+ithx)*pny*pnz;               \
2130
    tx      = thx[ithx];                       \
2131
    dx      = dthx[ithx];                      \
2132
                                               \
2133
    for(ithy=0; (ithy<order); ithy++)          \
2134
    {                                          \
2135
        index_xy = index_x+(j0+ithy)*pnz;      \
2136
        ty       = thy[ithy];                  \
2137
        dy       = dthy[ithy];                 \
2138
        fxy1     = fz1 = 0;                    \
2139
                                               \
2140
        for(ithz=0; (ithz<order); ithz++)      \
2141
        {                                      \
2142
            gval  = grid[index_xy+(k0+ithz)];  \
2143
            fxy1 += thz[ithz]*gval;            \
2144
            fz1  += dthz[ithz]*gval;           \
2145
        }                                      \
2146
        fx += dx*ty*fxy1;                      \
2147
        fy += tx*dy*fxy1;                      \
2148
        fz += tx*ty*fz1;                       \
2149
    }                                          \
2150
}
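/* Descriptive note on DO_FSPLINE: the force on an atom is minus the gradient
 * of the interpolated reciprocal-space energy, so along each dimension the
 * spline weights (thx/thy/thz) are replaced by their derivatives
 * (dthx/dthy/dthz). The accumulated fx, fy, fz are therefore derivatives with
 * respect to the scaled fractional coordinates; gather_f_bsplines below
 * converts them to Cartesian forces by multiplying with the grid dimensions
 * and the rows of pme->recipbox.
 */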
2151

    
2152

    
2153
static void gather_f_bsplines(gmx_pme_t pme,real *grid,
2154
                              gmx_bool bClearF,pme_atomcomm_t *atc,
2155
                              splinedata_t *spline,
2156
                              real scale)
2157
{
2158
    /* sum forces for local particles */
2159
    int     nn,n,ithx,ithy,ithz,i0,j0,k0;
2160
    int     index_x,index_xy;
2161
    int     nx,ny,nz,pnx,pny,pnz;
2162
    int *   idxptr;
2163
    real    tx,ty,dx,dy,qn;
2164
    real    fx,fy,fz,gval;
2165
    real    fxy1,fz1;
2166
    real    *thx,*thy,*thz,*dthx,*dthy,*dthz;
2167
    int     norder;
2168
    real    rxx,ryx,ryy,rzx,rzy,rzz;
2169
    int     order;
2170

    
2171
    pme_spline_work_t *work;
2172

    
2173
    work = pme->spline_work;
2174

    
2175
    order = pme->pme_order;
2176
    thx   = spline->theta[XX];
2177
    thy   = spline->theta[YY];
2178
    thz   = spline->theta[ZZ];
2179
    dthx  = spline->dtheta[XX];
2180
    dthy  = spline->dtheta[YY];
2181
    dthz  = spline->dtheta[ZZ];
2182
    nx    = pme->nkx;
2183
    ny    = pme->nky;
2184
    nz    = pme->nkz;
2185
    pnx   = pme->pmegrid_nx;
2186
    pny   = pme->pmegrid_ny;
2187
    pnz   = pme->pmegrid_nz;
2188

    
2189
    rxx   = pme->recipbox[XX][XX];
2190
    ryx   = pme->recipbox[YY][XX];
2191
    ryy   = pme->recipbox[YY][YY];
2192
    rzx   = pme->recipbox[ZZ][XX];
2193
    rzy   = pme->recipbox[ZZ][YY];
2194
    rzz   = pme->recipbox[ZZ][ZZ];
2195

    
2196
    for(nn=0; nn<spline->n; nn++)
2197
    {
2198
        n  = spline->ind[nn];
2199
        qn = scale*atc->q[n];
2200

    
2201
        if (bClearF)
2202
        {
2203
            atc->f[n][XX] = 0;
2204
            atc->f[n][YY] = 0;
2205
            atc->f[n][ZZ] = 0;
2206
        }
2207
        if (qn != 0)
2208
        {
2209
            fx     = 0;
2210
            fy     = 0;
2211
            fz     = 0;
2212
            idxptr = atc->idx[n];
2213
            norder = nn*order;
2214

    
2215
            i0   = idxptr[XX];
2216
            j0   = idxptr[YY];
2217
            k0   = idxptr[ZZ];
2218

    
2219
            /* Pointer arithmetic alert, next six statements */
2220
            thx  = spline->theta[XX] + norder;
2221
            thy  = spline->theta[YY] + norder;
2222
            thz  = spline->theta[ZZ] + norder;
2223
            dthx = spline->dtheta[XX] + norder;
2224
            dthy = spline->dtheta[YY] + norder;
2225
            dthz = spline->dtheta[ZZ] + norder;
2226

    
2227
            switch (order) {
2228
            case 4:
2229
#ifdef PME_SSE
2230
#ifdef PME_SSE_UNALIGNED
2231
#define PME_GATHER_F_SSE_ORDER4
2232
#else
2233
#define PME_GATHER_F_SSE_ALIGNED
2234
#define PME_ORDER 4
2235
#endif
2236
#include "pme_sse_single.h"
2237
#else
2238
                DO_FSPLINE(4);
2239
#endif
2240
                break;
2241
            case 5:
2242
#ifdef PME_SSE
2243
#define PME_GATHER_F_SSE_ALIGNED
2244
#define PME_ORDER 5
2245
#include "pme_sse_single.h"
2246
#else
2247
                DO_FSPLINE(5);
2248
#endif
2249
                break;
2250
            default:
2251
                DO_FSPLINE(order);
2252
                break;
2253
            }
2254

    
2255
            atc->f[n][XX] += -qn*( fx*nx*rxx );
2256
            atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2257
            atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2258
        }
2259
    }
2260
    /* Since the energy rather than the forces is interpolated,
2261
     * the net force might not be exactly zero.
2262
     * This can be solved by also interpolating F, but
2263
     * that comes at a cost.
2264
     * A better hack is to remove the net force every
2265
     * step, but that must be done at a higher level
2266
     * since this routine doesn't see all atoms if running
2267
     * in parallel. Don't know how important it is?  EL 990726
2268
     */
2269
}
2270

    
2271

    
2272
static real gather_energy_bsplines(gmx_pme_t pme,real *grid,
2273
                                   pme_atomcomm_t *atc)
2274
{
2275
    splinedata_t *spline;
2276
    int     n,ithx,ithy,ithz,i0,j0,k0;
2277
    int     index_x,index_xy;
2278
    int *   idxptr;
2279
    real    energy,pot,tx,ty,qn,gval;
2280
    real    *thx,*thy,*thz;
2281
    int     norder;
2282
    int     order;
2283

    
2284
    spline = &atc->spline[0];
2285

    
2286
    order = pme->pme_order;
2287

    
2288
    energy = 0;
2289
    for(n=0; (n<atc->n); n++) {
2290
        qn      = atc->q[n];
2291

    
2292
        if (qn != 0) {
2293
            idxptr = atc->idx[n];
2294
            norder = n*order;
2295

    
2296
            i0   = idxptr[XX];
2297
            j0   = idxptr[YY];
2298
            k0   = idxptr[ZZ];
2299

    
2300
            /* Pointer arithmetic alert, next three statements */
2301
            thx  = spline->theta[XX] + norder;
2302
            thy  = spline->theta[YY] + norder;
2303
            thz  = spline->theta[ZZ] + norder;
2304

    
2305
            pot = 0;
2306
            for(ithx=0; (ithx<order); ithx++)
2307
            {
2308
                index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2309
                tx      = thx[ithx];
2310

    
2311
                for(ithy=0; (ithy<order); ithy++)
2312
                {
2313
                    index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2314
                    ty       = thy[ithy];
2315

    
2316
                    for(ithz=0; (ithz<order); ithz++)
2317
                    {
2318
                        gval  = grid[index_xy+(k0+ithz)];
2319
                        pot  += tx*ty*thz[ithz]*gval;
2320
                    }
2321

    
2322
                }
2323
            }
2324

    
2325
            energy += pot*qn;
2326
        }
2327
    }
2328

    
2329
    return energy;
2330
}
2331

    
2332
/* Macro to force loop unrolling by fixing order.
2333
 * This gives a significant performance gain.
2334
 */
2335
#define CALC_SPLINE(order)                     \
2336
{                                              \
2337
    int j,k,l;                                 \
2338
    real dr,div;                               \
2339
    real data[PME_ORDER_MAX];                  \
2340
    real ddata[PME_ORDER_MAX];                 \
2341
                                               \
2342
    for(j=0; (j<DIM); j++)                     \
2343
    {                                          \
2344
        dr  = xptr[j];                         \
2345
                                               \
2346
        /* dr is relative offset from lower cell limit */ \
2347
        data[order-1] = 0;                     \
2348
        data[1] = dr;                          \
2349
        data[0] = 1 - dr;                      \
2350
                                               \
2351
        for(k=3; (k<order); k++)               \
2352
        {                                      \
2353
            div = 1.0/(k - 1.0);               \
2354
            data[k-1] = div*dr*data[k-2];      \
2355
            for(l=1; (l<(k-1)); l++)           \
2356
            {                                  \
2357
                data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2358
                                   data[k-l-1]);                \
2359
            }                                  \
2360
            data[0] = div*(1-dr)*data[0];      \
2361
        }                                      \
2362
        /* differentiate */                    \
2363
        ddata[0] = -data[0];                   \
2364
        for(k=1; (k<order); k++)               \
2365
        {                                      \
2366
            ddata[k] = data[k-1] - data[k];    \
2367
        }                                      \
2368
                                               \
2369
        div = 1.0/(order - 1);                 \
2370
        data[order-1] = div*dr*data[order-2];  \
2371
        for(l=1; (l<(order-1)); l++)           \
2372
        {                                      \
2373
            data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2374
                               (order-l-dr)*data[order-l-1]); \
2375
        }                                      \
2376
        data[0] = div*(1 - dr)*data[0];        \
2377
                                               \
2378
        for(k=0; k<order; k++)                 \
2379
        {                                      \
2380
            theta[j][i*order+k]  = data[k];    \
2381
            dtheta[j][i*order+k] = ddata[k];   \
2382
        }                                      \
2383
    }                                          \
2384
}
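/* Descriptive note on CALC_SPLINE: this is the standard cardinal B-spline
 * recursion used in smooth PME,
 *     M_k(u) = ( u*M_{k-1}(u) + (k - u)*M_{k-1}(u-1) ) / (k - 1),
 * built up to order-1; the "differentiate" step then uses
 *     M'_n(u) = M_{n-1}(u) - M_{n-1}(u-1)
 * for dtheta, and the final pass promotes the weights to the full order.
 */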
2385

    
2386
void make_bsplines(splinevec theta,splinevec dtheta,int order,
2387
                   rvec fractx[],int nr,int ind[],real charge[],
2388
                   gmx_bool bFreeEnergy)
2389
{
2390
    /* construct splines for local atoms */
2391
    int  i,ii;
2392
    real *xptr;
2393

    
2394
    for(i=0; i<nr; i++)
2395
    {
2396
        /* With free energy we do not use the charge check.
2397
         * In most cases this will be more efficient than calling make_bsplines
2398
         * twice, since usually more than half the particles have charges.
2399
         */
2400
        ii = ind[i];
2401
        if (bFreeEnergy || charge[ii] != 0.0) {
2402
            xptr = fractx[ii];
2403
            switch(order) {
2404
            case 4:  CALC_SPLINE(4);     break;
2405
            case 5:  CALC_SPLINE(5);     break;
2406
            default: CALC_SPLINE(order); break;
2407
            }
2408
        }
2409
    }
2410
}
2411

    
2412

    
2413
void make_dft_mod(real *mod,real *data,int ndata)
2414
{
2415
  int i,j;
2416
  real sc,ss,arg;
2417

    
2418
  for(i=0;i<ndata;i++) {
2419
    sc=ss=0;
2420
    for(j=0;j<ndata;j++) {
2421
      arg=(2.0*M_PI*i*j)/ndata;
2422
      sc+=data[j]*cos(arg);
2423
      ss+=data[j]*sin(arg);
2424
    }
2425
    mod[i]=sc*sc+ss*ss;
2426
  }
2427
  for(i=0;i<ndata;i++)
2428
    if(mod[i]<1e-7)
2429
      mod[i]=(mod[i-1]+mod[i+1])*0.5;
2430
}
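/* Descriptive note: make_dft_mod returns the squared modulus of the discrete
 * Fourier transform of data, mod[k] = |sum_j data[j]*exp(2*pi*i*j*k/ndata)|^2.
 * These are the B-spline moduli that appear in the denominator of the
 * reciprocal-space sum in solve_pme_yzx. Values that are numerically zero
 * (which can happen at the Nyquist mode, e.g. for odd interpolation orders)
 * are replaced by the average of the neighbouring values.
 */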
2431

    
2432

    
2433
static void make_bspline_moduli(splinevec bsp_mod,
2434
                                int nx,int ny,int nz,int order)
2435
{
2436
  int nmax=max(nx,max(ny,nz));
2437
  real *data,*ddata,*bsp_data;
2438
  int i,k,l;
2439
  real div;
2440

    
2441
  snew(data,order);
2442
  snew(ddata,order);
2443
  snew(bsp_data,nmax);
2444

    
2445
  data[order-1]=0;
2446
  data[1]=0;
2447
  data[0]=1;
2448

    
2449
  for(k=3;k<order;k++) {
2450
    div=1.0/(k-1.0);
2451
    data[k-1]=0;
2452
    for(l=1;l<(k-1);l++)
2453
      data[k-l-1]=div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2454
    data[0]=div*data[0];
2455
  }
2456
  /* differentiate */
2457
  ddata[0]=-data[0];
2458
  for(k=1;k<order;k++)
2459
    ddata[k]=data[k-1]-data[k];
2460
  div=1.0/(order-1);
2461
  data[order-1]=0;
2462
  for(l=1;l<(order-1);l++)
2463
    data[order-l-1]=div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2464
  data[0]=div*data[0];
2465

    
2466
  for(i=0;i<nmax;i++)
2467
    bsp_data[i]=0;
2468
  for(i=1;i<=order;i++)
2469
    bsp_data[i]=data[i-1];
2470

    
2471
  make_dft_mod(bsp_mod[XX],bsp_data,nx);
2472
  make_dft_mod(bsp_mod[YY],bsp_data,ny);
2473
  make_dft_mod(bsp_mod[ZZ],bsp_data,nz);
2474

    
2475
  sfree(data);
2476
  sfree(ddata);
2477
  sfree(bsp_data);
2478
}
2479

    
2480

    
2481
/* Return the P3M optimal influence function */
2482
static double do_p3m_influence(double z, int order)
2483
{
2484
    double z2,z4;
2485

    
2486
    z2 = z*z;
2487
    z4 = z2*z2;
2488

    
2489
    /* The formula and most constants can be found in:
2490
     * Ballenegger et al., JCTC 8, 936 (2012)
2491
     */
2492
    switch(order)
2493
    {
2494
    case 2:
2495
        return 1.0 - 2.0*z2/3.0;
2496
        break;
2497
    case 3:
2498
        return 1.0 - z2 + 2.0*z4/15.0;
2499
        break;
2500
    case 4:
2501
        return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2502
        break;
2503
    case 5:
2504
        return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2505
        break;
2506
    case 6:
2507
        return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2508
        break;
2509
    case 7:
2510
        return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2511
    case 8:
2512
        return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2513
        break;
2514
    }
2515

    
2516
    return 0.0;
2517
}
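/* Descriptive note: the polynomials above are the order-dependent series for
 * the P3M influence-function correction from the reference cited above,
 * evaluated at sin(pi*k/n). In make_p3m_bspline_moduli_dim below the result
 * is squared and multiplied by (sin(zai)/zai)^(-2*order), with zai = pi*k/n,
 * to build the bsp_mod entries.
 */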
2518

    
2519
/* Calculate the P3M B-spline moduli for one dimension */
2520
static void make_p3m_bspline_moduli_dim(real *bsp_mod,int n,int order)
2521
{
2522
    double zarg,zai,sinzai,infl;
2523
    int    maxk,i;
2524

    
2525
    if (order > 8)
2526
    {
2527
        gmx_fatal(FARGS,"The current P3M code only supports orders up to 8");
2528
    }
2529

    
2530
    zarg = M_PI/n;
2531

    
2532
    maxk = (n + 1)/2;
2533

    
2534
    for(i=-maxk; i<0; i++)
2535
    {
2536
        zai    = zarg*i;
2537
        sinzai = sin(zai);
2538
        infl   = do_p3m_influence(sinzai,order);
2539
        bsp_mod[n+i] = infl*infl*pow(sinzai/zai,-2.0*order);
2540
    }
2541
    bsp_mod[0] = 1.0;
2542
    for(i=1; i<maxk; i++)
2543
    {
2544
        zai    = zarg*i;
2545
        sinzai = sin(zai);
2546
        infl   = do_p3m_influence(sinzai,order);
2547
        bsp_mod[i] = infl*infl*pow(sinzai/zai,-2.0*order);
2548
    }
2549
}
2550

    
2551
/* Calculate the P3M B-spline moduli */
2552
static void make_p3m_bspline_moduli(splinevec bsp_mod,
2553
                                    int nx,int ny,int nz,int order)
2554
{
2555
    make_p3m_bspline_moduli_dim(bsp_mod[XX],nx,order);
2556
    make_p3m_bspline_moduli_dim(bsp_mod[YY],ny,order);
2557
    make_p3m_bspline_moduli_dim(bsp_mod[ZZ],nz,order);
2558
}
2559

    
2560

    
2561
static void setup_coordinate_communication(pme_atomcomm_t *atc)
2562
{
2563
  int nslab,n,i;
2564
  int fw,bw;
2565

    
2566
  nslab = atc->nslab;
2567

    
2568
  n = 0;
2569
  for(i=1; i<=nslab/2; i++) {
2570
    fw = (atc->nodeid + i) % nslab;
2571
    bw = (atc->nodeid - i + nslab) % nslab;
2572
    if (n < nslab - 1) {
2573
      atc->node_dest[n] = fw;
2574
      atc->node_src[n]  = bw;
2575
      n++;
2576
    }
2577
    if (n < nslab - 1) {
2578
      atc->node_dest[n] = bw;
2579
      atc->node_src[n]  = fw;
2580
      n++;
2581
    }
2582
  }
2583
}
2584

    
2585
int gmx_pme_destroy(FILE *log,gmx_pme_t *pmedata)
2586
{
2587
    int thread;
2588

    
2589
    if(NULL != log)
2590
    {
2591
        fprintf(log,"Destroying PME data structures.\n");
2592
    }
2593

    
2594
    sfree((*pmedata)->nnx);
2595
    sfree((*pmedata)->nny);
2596
    sfree((*pmedata)->nnz);
2597

    
2598
    pmegrids_destroy(&(*pmedata)->pmegridA);
2599

    
2600
    sfree((*pmedata)->fftgridA);
2601
    sfree((*pmedata)->cfftgridA);
2602
    gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2603

    
2604
    if ((*pmedata)->pmegridB.grid.grid != NULL)
2605
    {
2606
        pmegrids_destroy(&(*pmedata)->pmegridB);
2607
        sfree((*pmedata)->fftgridB);
2608
        sfree((*pmedata)->cfftgridB);
2609
        gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2610
    }
2611
    for(thread=0; thread<(*pmedata)->nthread; thread++)
2612
    {
2613
        free_work(&(*pmedata)->work[thread]);
2614
    }
2615
    sfree((*pmedata)->work);
2616

    
2617
    sfree(*pmedata);
2618
    *pmedata = NULL;
2619

    
2620
  return 0;
2621
}
2622

    
2623
static int mult_up(int n,int f)
2624
{
2625
    return ((n + f - 1)/f)*f;
2626
}
2627

    
2628

    
2629
static double pme_load_imbalance(gmx_pme_t pme)
2630
{
2631
    int    nma,nmi;
2632
    double n1,n2,n3;
2633

    
2634
    nma = pme->nnodes_major;
2635
    nmi = pme->nnodes_minor;
2636

    
2637
    n1 = mult_up(pme->nkx,nma)*mult_up(pme->nky,nmi)*pme->nkz;
2638
    n2 = mult_up(pme->nkx,nma)*mult_up(pme->nkz,nmi)*pme->nky;
2639
    n3 = mult_up(pme->nky,nma)*mult_up(pme->nkz,nmi)*pme->nkx;
2640

    
2641
    /* pme_solve is roughly double the cost of an fft */
2642

    
2643
    return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2644
}
2645

    
2646
static void init_atomcomm(gmx_pme_t pme,pme_atomcomm_t *atc, t_commrec *cr,
2647
                          int dimind,gmx_bool bSpread)
2648
{
2649
    int nk,k,s,thread;
2650

    
2651
    atc->dimind = dimind;
2652
    atc->nslab  = 1;
2653
    atc->nodeid = 0;
2654
    atc->pd_nalloc = 0;
2655
#ifdef GMX_MPI
2656
    if (pme->nnodes > 1)
2657
    {
2658
        atc->mpi_comm = pme->mpi_comm_d[dimind];
2659
        MPI_Comm_size(atc->mpi_comm,&atc->nslab);
2660
        MPI_Comm_rank(atc->mpi_comm,&atc->nodeid);
2661
    }
2662
    if (debug)
2663
    {
2664
        fprintf(debug,"For PME atom communication in dimind %d: nslab %d rank %d\n",atc->dimind,atc->nslab,atc->nodeid);
2665
    }
2666
#endif
2667

    
2668
    atc->bSpread   = bSpread;
2669
    atc->pme_order = pme->pme_order;
2670

    
2671
    if (atc->nslab > 1)
2672
    {
2673
        /* These three allocations are not required for particle decomp. */
2674
        snew(atc->node_dest,atc->nslab);
2675
        snew(atc->node_src,atc->nslab);
2676
        setup_coordinate_communication(atc);
2677

    
2678
        snew(atc->count_thread,pme->nthread);
2679
        for(thread=0; thread<pme->nthread; thread++)
2680
        {
2681
            snew(atc->count_thread[thread],atc->nslab);
2682
        }
2683
        atc->count = atc->count_thread[0];
2684
        snew(atc->rcount,atc->nslab);
2685
        snew(atc->buf_index,atc->nslab);
2686
    }
2687

    
2688
    atc->nthread = pme->nthread;
2689
    if (atc->nthread > 1)
2690
    {
2691
        snew(atc->thread_plist,atc->nthread);
2692
    }
2693
    snew(atc->spline,atc->nthread);
2694
    for(thread=0; thread<atc->nthread; thread++)
2695
    {
2696
        if (atc->nthread > 1)
2697
        {
2698
            snew(atc->thread_plist[thread].n,atc->nthread+2*GMX_CACHE_SEP);
2699
            atc->thread_plist[thread].n += GMX_CACHE_SEP;
2700
        }
2701
    }
2702
}
2703

    
2704
static void
2705
init_overlap_comm(pme_overlap_t *  ol,
2706
                  int              norder,
2707
#ifdef GMX_MPI
2708
                  MPI_Comm         comm,
2709
#endif
2710
                  int              nnodes,
2711
                  int              nodeid,
2712
                  int              ndata,
2713
                  int              commplainsize)
2714
{
2715
    int lbnd,rbnd,maxlr,b,i;
2716
    int exten;
2717
    int nn,nk;
2718
    pme_grid_comm_t *pgc;
2719
    gmx_bool bCont;
2720
    int fft_start,fft_end,send_index1,recv_index1;
2721

    
2722
#ifdef GMX_MPI
2723
    ol->mpi_comm = comm;
2724
#endif
2725

    
2726
    ol->nnodes = nnodes;
2727
    ol->nodeid = nodeid;
2728

    
2729
    /* Linear translation of the PME grid won't affect reciprocal space
2730
     * calculations, so to optimize we only interpolate "upwards",
2731
     * which also means we only have to consider overlap in one direction.
2732
     * I.e., particles on this node might also be spread to grid indices
2733
     * that belong to higher nodes (modulo nnodes)
2734
     */
2735

    
2736
    snew(ol->s2g0,ol->nnodes+1);
2737
    snew(ol->s2g1,ol->nnodes);
2738
    if (debug) { fprintf(debug,"PME slab boundaries:"); }
2739
    for(i=0; i<nnodes; i++)
2740
    {
2741
        /* s2g0 the local interpolation grid start.
2742
         * s2g1 the local interpolation grid end.
2743
         * Because grid overlap communication only goes forward,
2744
         * the grid slabs for the FFTs should be rounded down.
2745
         */
2746
        ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2747
        ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2748

    
2749
        if (debug)
2750
        {
2751
            fprintf(debug,"  %3d %3d",ol->s2g0[i],ol->s2g1[i]);
2752
        }
2753
    }
2754
    ol->s2g0[nnodes] = ndata;
2755
    if (debug) { fprintf(debug,"\n"); }
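    /* Worked example (illustration only): with ndata = 15 grid lines,
     * nnodes = 4 and norder = 4, s2g0 = {0,3,7,11,15} and s2g1 = {7,11,15,18}.
     * Slab 0 then owns fft grid lines [0,3) but spreads charges onto
     * lines [0,7), so the overlap region [3,7) is communicated forward to
     * the next slab, and the loop below finds noverlap_nodes = 1.
     */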
2756

    
2757
    /* Determine with how many nodes we need to communicate the grid overlap */
2758
    b = 0;
2759
    do
2760
    {
2761
        b++;
2762
        bCont = FALSE;
2763
        for(i=0; i<nnodes; i++)
2764
        {
2765
            if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2766
                (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2767
            {
2768
                bCont = TRUE;
2769
            }
2770
        }
2771
    }
2772
    while (bCont && b < nnodes);
2773
    ol->noverlap_nodes = b - 1;
2774

    
2775
    snew(ol->send_id,ol->noverlap_nodes);
2776
    snew(ol->recv_id,ol->noverlap_nodes);
2777
    for(b=0; b<ol->noverlap_nodes; b++)
2778
    {
2779
        ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2780
        ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2781
    }
2782
    snew(ol->comm_data, ol->noverlap_nodes);
2783

    
2784
    for(b=0; b<ol->noverlap_nodes; b++)
2785
    {
2786
        pgc = &ol->comm_data[b];
2787
        /* Send */
2788
        fft_start        = ol->s2g0[ol->send_id[b]];
2789
        fft_end          = ol->s2g0[ol->send_id[b]+1];
2790
        if (ol->send_id[b] < nodeid)
2791
        {
2792
            fft_start += ndata;
2793
            fft_end   += ndata;
2794
        }
2795
        send_index1      = ol->s2g1[nodeid];
2796
        send_index1      = min(send_index1,fft_end);
2797
        pgc->send_index0 = fft_start;
2798
        pgc->send_nindex = max(0,send_index1 - pgc->send_index0);
2799

    
2800
        /* We always start receiving at the first index of our slab */
2801
        fft_start        = ol->s2g0[ol->nodeid];
2802
        fft_end          = ol->s2g0[ol->nodeid+1];
2803
        recv_index1      = ol->s2g1[ol->recv_id[b]];
2804
        if (ol->recv_id[b] > nodeid)
2805
        {
2806
            recv_index1 -= ndata;
2807
        }
2808
        recv_index1      = min(recv_index1,fft_end);
2809
        pgc->recv_index0 = fft_start;
2810
        pgc->recv_nindex = max(0,recv_index1 - pgc->recv_index0);
2811
    }
2812

    
2813
    /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2814
    snew(ol->sendbuf,norder*commplainsize);
2815
    snew(ol->recvbuf,norder*commplainsize);
2816
}
2817

    
2818
static void
2819
make_gridindex5_to_localindex(int n,int local_start,int local_range,
2820
                              int **global_to_local,
2821
                              real **fraction_shift)
2822
{
2823
    int i;
2824
    int * gtl;
2825
    real * fsh;
2826

    
2827
    snew(gtl,5*n);
2828
    snew(fsh,5*n);
2829
    for(i=0; (i<5*n); i++)
2830
    {
2831
        /* Determine the global to local grid index */
2832
        gtl[i] = (i - local_start + n) % n;
2833
        /* For coordinates that fall within the local grid the fraction
2834
         * is correct; we don't need to shift it.
2835
         */
2836
        fsh[i] = 0;
2837
        if (local_range < n)
2838
        {
2839
            /* Due to rounding issues i could be 1 beyond the lower or
2840
             * upper boundary of the local grid. Correct the index for this.
2841
             * If we shift the index, we need to shift the fraction by
2842
             * the same amount in the other direction to not affect
2843
             * the weights.
2844
             * Note that due to this shifting the weights at the end of
2845
             * the spline might change, but that will only involve values
2846
             * between zero and values close to the precision of a real,
2847
             * which is anyhow the accuracy of the whole mesh calculation.
2848
             */
2849
            /* With local_range=0 we should not change i=local_start */
2850
            if (i % n != local_start)
2851
            {
2852
                if (gtl[i] == n-1)
2853
                {
2854
                    gtl[i] = 0;
2855
                    fsh[i] = -1;
2856
                }
2857
                else if (gtl[i] == local_range)
2858
                {
2859
                    gtl[i] = local_range - 1;
2860
                    fsh[i] = 1;
2861
                }
2862
            }
2863
        }
2864
    }
2865

    
2866
    *global_to_local = gtl;
2867
    *fraction_shift  = fsh;
2868
}
2869

    
2870
static pme_spline_work_t *make_pme_spline_work(int order)
2871
{
2872
    pme_spline_work_t *work;
2873

    
2874
#ifdef PME_SSE
2875
    float  tmp[8];
2876
    __m128 zero_SSE;
2877
    int    of,i;
2878

    
2879
    snew_aligned(work,1,16);
2880

    
2881
    zero_SSE = _mm_setzero_ps();
2882

    
2883
    /* Generate bit masks to mask out the unused grid entries,
2884
     * as we only operate on 'order' of the 8 grid entries that are
2885
     * loaded into 2 SSE float registers.
2886
     */
2887
    for(of=0; of<8-(order-1); of++)
2888
    {
2889
        for(i=0; i<8; i++)
2890
        {
2891
            tmp[i] = (i >= of && i < of+order ? 1 : 0);
2892
        }
2893
        work->mask_SSE0[of] = _mm_loadu_ps(tmp);
2894
        work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
2895
        work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of],zero_SSE);
2896
        work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of],zero_SSE);
2897
    }
2898
#else
2899
    work = NULL;
2900
#endif
2901

    
2902
    return work;
2903
}
2904

    
2905
static void
2906
gmx_pme_check_grid_restrictions(FILE *fplog,char dim,int nnodes,int *nk)
2907
{
2908
    int nk_new;
2909

    
2910
    if (*nk % nnodes != 0)
2911
    {
2912
        nk_new = nnodes*(*nk/nnodes + 1);
2913

    
2914
        if (2*nk_new >= 3*(*nk))
2915
        {
2916
            gmx_fatal(FARGS,"The PME grid size in dim %c (%d) is not divisble by the number of nodes doing PME in dim %c (%d). The grid size would have to be increased by more than 50%% to make the grid divisible. Change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).",
2917
                      dim,*nk,dim,nnodes,dim);
2918
        }
2919

    
2920
        if (fplog != NULL)
2921
        {
2922
            fprintf(fplog,"\nNOTE: The PME grid size in dim %c (%d) is not divisble by the number of nodes doing PME in dim %c (%d). Increasing the PME grid size in dim %c to %d. This will increase the accuracy and will not decrease the performance significantly on this number of nodes. For optimal performance change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).\n\n",
2923
                    dim,*nk,dim,nnodes,dim,nk_new,dim);
2924
        }
2925

    
2926
        *nk = nk_new;
2927
    }
2928
}
2929

    
2930
int gmx_pme_init(gmx_pme_t *         pmedata,
2931
                 t_commrec *         cr,
2932
                 int                 nnodes_major,
2933
                 int                 nnodes_minor,
2934
                 t_inputrec *        ir,
2935
                 int                 homenr,
2936
                 gmx_bool            bFreeEnergy,
2937
                 gmx_bool            bReproducible,
2938
                 int                 nthread)
2939
{
2940
    gmx_pme_t pme=NULL;
2941

    
2942
    pme_atomcomm_t *atc;
2943
    ivec ndata;
2944

    
2945
    if (debug)
2946
        fprintf(debug,"Creating PME data structures.\n");
2947
    snew(pme,1);
2948

    
2949
    pme->redist_init         = FALSE;
2950
    pme->sum_qgrid_tmp       = NULL;
2951
    pme->sum_qgrid_dd_tmp    = NULL;
2952
    pme->buf_nalloc          = 0;
2953
    pme->redist_buf_nalloc   = 0;
2954

    
2955
    pme->nnodes              = 1;
2956
    pme->bPPnode             = TRUE;
2957

    
2958
    pme->nnodes_major        = nnodes_major;
2959
    pme->nnodes_minor        = nnodes_minor;
2960

    
2961
#ifdef GMX_MPI
2962
    if (nnodes_major*nnodes_minor > 1)
2963
    {
2964
        pme->mpi_comm = cr->mpi_comm_mygroup;
2965

    
2966
        MPI_Comm_rank(pme->mpi_comm,&pme->nodeid);
2967
        MPI_Comm_size(pme->mpi_comm,&pme->nnodes);
2968
        if (pme->nnodes != nnodes_major*nnodes_minor)
2969
        {
2970
            gmx_incons("PME node count mismatch");
2971
        }
2972
    }
2973
    else
2974
    {
2975
        pme->mpi_comm = MPI_COMM_NULL;
2976
    }
2977
#endif
2978

    
2979
    if (pme->nnodes == 1)
2980
    {
2981
        pme->ndecompdim = 0;
2982
        pme->nodeid_major = 0;
2983
        pme->nodeid_minor = 0;
2984
#ifdef GMX_MPI
2985
        pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
2986
#endif
2987
    }
2988
    else
2989
    {
2990
        if (nnodes_minor == 1)
2991
        {
2992
#ifdef GMX_MPI
2993
            pme->mpi_comm_d[0] = pme->mpi_comm;
2994
            pme->mpi_comm_d[1] = MPI_COMM_NULL;
2995
#endif
2996
            pme->ndecompdim = 1;
2997
            pme->nodeid_major = pme->nodeid;
2998
            pme->nodeid_minor = 0;
2999

    
3000
        }
3001
        else if (nnodes_major == 1)
3002
        {
3003
#ifdef GMX_MPI
3004
            pme->mpi_comm_d[0] = MPI_COMM_NULL;
3005
            pme->mpi_comm_d[1] = pme->mpi_comm;
3006
#endif
3007
            pme->ndecompdim = 1;
3008
            pme->nodeid_major = 0;
3009
            pme->nodeid_minor = pme->nodeid;
3010
        }
3011
        else
3012
        {
3013
            if (pme->nnodes % nnodes_major != 0)
3014
            {
3015
                gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3016
            }
3017
            pme->ndecompdim = 2;
3018

    
3019
#ifdef GMX_MPI
3020
            MPI_Comm_split(pme->mpi_comm,pme->nodeid % nnodes_minor,
3021
                           pme->nodeid,&pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3022
            MPI_Comm_split(pme->mpi_comm,pme->nodeid/nnodes_minor,
3023
                           pme->nodeid,&pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3024

    
3025
            MPI_Comm_rank(pme->mpi_comm_d[0],&pme->nodeid_major);
3026
            MPI_Comm_size(pme->mpi_comm_d[0],&pme->nnodes_major);
3027
            MPI_Comm_rank(pme->mpi_comm_d[1],&pme->nodeid_minor);
3028
            MPI_Comm_size(pme->mpi_comm_d[1],&pme->nnodes_minor);
3029
#endif
3030
        }
3031
        pme->bPPnode = (cr->duty & DUTY_PP);
3032
    }
3033

    
3034
    pme->nthread = nthread;
3035

    
3036
    if (ir->ePBC == epbcSCREW)
3037
    {
3038
        gmx_fatal(FARGS,"pme does not (yet) work with pbc = screw");
3039
    }
3040

    
3041
    pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3042
    pme->nkx         = ir->nkx;
3043
    pme->nky         = ir->nky;
3044
    pme->nkz         = ir->nkz;
3045
    pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3046
    pme->pme_order   = ir->pme_order;
3047
    pme->epsilon_r   = ir->epsilon_r;
3048

    
3049
    if (pme->pme_order > PME_ORDER_MAX)
3050
    {
3051
        gmx_fatal(FARGS,"pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3052
                  pme->pme_order,PME_ORDER_MAX);
3053
    }
3054

    
3055
    /* Currently pme.c supports only the fft5d FFT code.
3056
     * Therefore the grid always needs to be divisible by nnodes.
3057
     * When the old 1D code is also supported again, change this check.
3058
     *
3059
     * This check should be done before calling gmx_pme_init
3060
     * and fplog should be passed iso stderr.
3061
     *
3062
    if (pme->ndecompdim >= 2)
3063
    */
3064
    if (pme->ndecompdim >= 1)
3065
    {
3066
        /*
3067
        gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3068
                                        'x',nnodes_major,&pme->nkx);
3069
        gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3070
                                        'y',nnodes_minor,&pme->nky);
3071
        */
3072
    }
3073

    
3074
    if (pme->nkx <= pme->pme_order*(pme->nnodes_major > 1 ? 2 : 1) ||
3075
        pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
3076
        pme->nkz <= pme->pme_order)
3077
    {
3078
        gmx_fatal(FARGS,"The pme grid dimensions need to be larger than pme_order (%d) and in parallel larger than 2*pme_ordern for x and/or y",pme->pme_order);
3079
    }
3080

    
3081
    if (pme->nnodes > 1) {
3082
        double imbal;
3083

    
3084
#ifdef GMX_MPI
3085
        MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3086
        MPI_Type_commit(&(pme->rvec_mpi));
3087
#endif
3088

    
3089
        /* Note that the charge spreading and force gathering, which usually
3090
         * takes about the same amount of time as FFT+solve_pme,
3091
         * is always fully load balanced
3092
         * (unless the charge distribution is inhomogeneous).
3093
         */
3094

    
3095
        imbal = pme_load_imbalance(pme);
3096
        if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3097
        {
3098
            fprintf(stderr,
3099
                    "\n"
3100
                    "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3101
                    "      For optimal PME load balancing\n"
3102
                    "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3103
                    "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3104
                    "\n",
3105
                    (int)((imbal-1)*100 + 0.5),
3106
                    pme->nkx,pme->nky,pme->nnodes_major,
3107
                    pme->nky,pme->nkz,pme->nnodes_minor);
3108
        }
3109
    }
3110

    
3111
    /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3112
    /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3113
     * y is always copied through a buffer: we don't need padding in z,
3114
     * but we do need the overlap in x because of the communication order.
3115
     */
3116
    init_overlap_comm(&pme->overlap[0],pme->pme_order,
3117
#ifdef GMX_MPI
3118
                      pme->mpi_comm_d[0],
3119
#endif
3120
                      pme->nnodes_major,pme->nodeid_major,
3121
                      pme->nkx,
3122
                      (div_round_up(pme->nky,pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3123

    
3124
    init_overlap_comm(&pme->overlap[1],pme->pme_order,
3125
#ifdef GMX_MPI
3126
                      pme->mpi_comm_d[1],
3127
#endif
3128
                      pme->nnodes_minor,pme->nodeid_minor,
3129
                      pme->nky,
3130
                      (div_round_up(pme->nkx,pme->nnodes_major)+pme->pme_order)*pme->nkz);
3131

    
3132
    /* Check for a limitation of the (current) sum_fftgrid_dd code */
3133
    if (pme->nthread > 1 &&
3134
        (pme->overlap[0].noverlap_nodes > 1 ||
3135
         pme->overlap[1].noverlap_nodes > 1))
3136
    {
3137
        gmx_fatal(FARGS,"With threads the number of grid lines per node along x and or y should be pme_order (%d) or more or exactly pme_order-1",pme->pme_order);
3138
    }
3139

    
3140
    snew(pme->bsp_mod[XX],pme->nkx);
3141
    snew(pme->bsp_mod[YY],pme->nky);
3142
    snew(pme->bsp_mod[ZZ],pme->nkz);
3143

    
3144
    /* The required size of the interpolation grid, including overlap.
3145
     * The allocated size (pmegrid_n?) might be slightly larger.
3146
     */
3147
    pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3148
                      pme->overlap[0].s2g0[pme->nodeid_major];
3149
    pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3150
                      pme->overlap[1].s2g0[pme->nodeid_minor];
3151
    pme->pmegrid_nz_base = pme->nkz;
3152
    pme->pmegrid_nz = pme->pmegrid_nz_base + pme->pme_order - 1;
3153
    set_grid_alignment(&pme->pmegrid_nz,pme->pme_order);
3154

    
3155
    pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3156
    pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3157
    pme->pmegrid_start_iz = 0;
3158

    
3159
    make_gridindex5_to_localindex(pme->nkx,
3160
                                  pme->pmegrid_start_ix,
3161
                                  pme->pmegrid_nx - (pme->pme_order-1),
3162
                                  &pme->nnx,&pme->fshx);
3163
    make_gridindex5_to_localindex(pme->nky,
3164
                                  pme->pmegrid_start_iy,
3165
                                  pme->pmegrid_ny - (pme->pme_order-1),
3166
                                  &pme->nny,&pme->fshy);
3167
    make_gridindex5_to_localindex(pme->nkz,
3168
                                  pme->pmegrid_start_iz,
3169
                                  pme->pmegrid_nz_base,
3170
                                  &pme->nnz,&pme->fshz);
3171

    
3172
    pmegrids_init(&pme->pmegridA,
3173
                  pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3174
                  pme->pmegrid_nz_base,
3175
                  pme->pme_order,
3176
                  pme->nthread,
3177
                  pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3178
                  pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3179

    
3180
    pme->spline_work = make_pme_spline_work(pme->pme_order);
3181

    
3182
    ndata[0] = pme->nkx;
3183
    ndata[1] = pme->nky;
3184
    ndata[2] = pme->nkz;
3185

    
3186
    /* This routine will allocate the grid data to fit the FFTs */
3187
    gmx_parallel_3dfft_init(&pme->pfft_setupA,ndata,
3188
                            &pme->fftgridA,&pme->cfftgridA,
3189
                            pme->mpi_comm_d,
3190
                            pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3191
                            bReproducible,pme->nthread);
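
    /* Note (general real-to-complex FFT property, added for clarity): a
     * real-to-complex 3D FFT of an nkx x nky x nkz real grid produces on the
     * order of
     *
     *     nkx*nky*(nkz/2 + 1)
     *
     * complex values, because the transform of real data is conjugate
     * symmetric along the last transformed dimension. This is why the real
     * grid (fftgridA) and the complex grid (cfftgridA) are allocated
     * separately by gmx_parallel_3dfft_init(), each with its own size and
     * layout.
     */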
3192

    
3193
    if (bFreeEnergy)
3194
    {
3195
        pmegrids_init(&pme->pmegridB,
3196
                      pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3197
                      pme->pmegrid_nz_base,
3198
                      pme->pme_order,
3199
                      pme->nthread,
3200
                      pme->nkx % pme->nnodes_major != 0,
3201
                      pme->nky % pme->nnodes_minor != 0);
3202

    
3203
        gmx_parallel_3dfft_init(&pme->pfft_setupB,ndata,
3204
                                &pme->fftgridB,&pme->cfftgridB,
3205
                                pme->mpi_comm_d,
3206
                                pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3207
                                bReproducible,pme->nthread);
3208
    }
3209
    else
3210
    {
3211
        pme->pmegridB.grid.grid = NULL;
3212
        pme->fftgridB           = NULL;
3213
        pme->cfftgridB          = NULL;
3214
    }
3215

    
3216
    if (!pme->bP3M)
3217
    {
3218
        /* Use plain SPME B-spline interpolation */
3219
        make_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3220
    }
3221
    else
3222
    {
3223
        /* Use the P3M grid-optimized influence function */
3224
        make_p3m_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3225
    }
3226

    
3227
    /* Use atc[0] for spreading */
3228
    init_atomcomm(pme,&pme->atc[0],cr,nnodes_major > 1 ? 0 : 1,TRUE);
3229
    if (pme->ndecompdim >= 2)
3230
    {
3231
        init_atomcomm(pme,&pme->atc[1],cr,1,FALSE);
3232
    }
3233

    
3234
    if (pme->nnodes == 1) {
3235
        pme->atc[0].n = homenr;
3236
        pme_realloc_atomcomm_things(&pme->atc[0]);
3237
    }
3238

    
3239
    {
3240
        int thread;
3241

    
3242
        /* Use fft5d; after the FFT the data order is y major, z middle, x minor */
3243

    
3244
        snew(pme->work,pme->nthread);
3245
        for(thread=0; thread<pme->nthread; thread++)
3246
        {
3247
            realloc_work(&pme->work[thread],pme->nkx);
3248
        }
3249
    }
3250

    
3251
    *pmedata = pme;
3252

    
3253
    return 0;
3254
}
3255

    
3256

    
3257
static void copy_local_grid(gmx_pme_t pme,
3258
                            pmegrids_t *pmegrids,int thread,real *fftgrid)
3259
{
3260
    ivec local_fft_ndata,local_fft_offset,local_fft_size;
3261
    int  fft_my,fft_mz;
3262
    int  nsx,nsy,nsz;
3263
    ivec nf;
3264
    int  offx,offy,offz,x,y,z,i0,i0t;
3265
    int  d;
3266
    pmegrid_t *pmegrid;
3267
    real *grid_th;
3268

    
3269
    gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3270
                                   local_fft_ndata,
3271
                                   local_fft_offset,
3272
                                   local_fft_size);
3273
    fft_my = local_fft_size[YY];
3274
    fft_mz = local_fft_size[ZZ];
3275

    
3276
    pmegrid = &pmegrids->grid_th[thread];
3277

    
3278
    nsx = pmegrid->n[XX];
3279
    nsy = pmegrid->n[YY];
3280
    nsz = pmegrid->n[ZZ];
3281

    
3282
    for(d=0; d<DIM; d++)
3283
    {
3284
        nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3285
                    local_fft_ndata[d] - pmegrid->offset[d]);
3286
    }
3287

    
3288
    offx = pmegrid->offset[XX];
3289
    offy = pmegrid->offset[YY];
3290
    offz = pmegrid->offset[ZZ];
3291

    
3292
    /* Directly copy the non-overlapping parts of the local grids.
3293
     * This also initializes the full grid.
3294
     */
3295
    grid_th = pmegrid->grid;
3296
    for(x=0; x<nf[XX]; x++)
3297
    {
3298
        for(y=0; y<nf[YY]; y++)
3299
        {
3300
            i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3301
            i0t = (x*nsy + y)*nsz;
3302
            for(z=0; z<nf[ZZ]; z++)
3303
            {
3304
                fftgrid[i0+z] = grid_th[i0t+z];
3305
            }
3306
        }
3307
    }
3308
}
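
/* Illustration (hypothetical helper, for clarity only): copy_local_grid()
 * above, and the reduction and communication routines below, address 3D
 * grids stored contiguously in row-major order with z running fastest.
 * Their index arithmetic is equivalent to
 *
 *     static int flat_index_sketch(int x, int y, int z, int my, int mz)
 *     {
 *         return (x*my + y)*mz + z;
 *     }
 *
 * where my and mz are the allocated y and z sizes (the strides). Thus
 * i0 = ((offx + x)*fft_my + (offy + y))*fft_mz + offz points at the start of
 * the z-row at grid point (offx+x, offy+y, offz), and consecutive z values
 * are consecutive in memory.
 */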
3309

    
3310
static void print_sendbuf(gmx_pme_t pme,real *sendbuf)
3311
{
3312
    ivec local_fft_ndata,local_fft_offset,local_fft_size;
3313
    pme_overlap_t *overlap;
3314
    int datasize,nind;
3315
    int i,x,y,z,n;
3316

    
3317
    gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3318
                                   local_fft_ndata,
3319
                                   local_fft_offset,
3320
                                   local_fft_size);
3321
    /* Major dimension */
3322
    overlap = &pme->overlap[0];
3323

    
3324
    nind   = overlap->comm_data[0].send_nindex;
3325

    
3326
    for(y=0; y<local_fft_ndata[YY]; y++) {
3327
         printf(" %2d",y);
3328
    }
3329
    printf("\n");
3330

    
3331
    i = 0;
3332
    for(x=0; x<nind; x++) {
3333
        for(y=0; y<local_fft_ndata[YY]; y++) {
3334
            n = 0;
3335
            for(z=0; z<local_fft_ndata[ZZ]; z++) {
3336
                if (sendbuf[i] != 0) n++;
3337
                i++;
3338
            }
3339
            printf(" %2d",n);
3340
        }
3341
        printf("\n");
3342
    }
3343
}
3344

    
3345
static void
3346
reduce_threadgrid_overlap(gmx_pme_t pme,
3347
                          const pmegrids_t *pmegrids,int thread,
3348
                          real *fftgrid,real *commbuf_x,real *commbuf_y)
3349
{
3350
    ivec local_fft_ndata,local_fft_offset,local_fft_size;
3351
    int  fft_nx,fft_ny,fft_nz;
3352
    int  fft_my,fft_mz;
3353
    int  buf_my=-1;
3354
    int  nsx,nsy,nsz;
3355
    ivec ne;
3356
    int  offx,offy,offz,x,y,z,i0,i0t;
3357
    int  sx,sy,sz,fx,fy,fz,tx1,ty1,tz1,ox,oy,oz;
3358
    gmx_bool bClearBufX,bClearBufY,bClearBufXY,bClearBuf;
3359
    gmx_bool bCommX,bCommY;
3360
    int  d;
3361
    int  thread_f;
3362
    const pmegrid_t *pmegrid,*pmegrid_g,*pmegrid_f;
3363
    const real *grid_th;
3364
    real *commbuf=NULL;
3365

    
3366
    gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3367
                                   local_fft_ndata,
3368
                                   local_fft_offset,
3369
                                   local_fft_size);
3370
    fft_nx = local_fft_ndata[XX];
3371
    fft_ny = local_fft_ndata[YY];
3372
    fft_nz = local_fft_ndata[ZZ];
3373

    
3374
    fft_my = local_fft_size[YY];
3375
    fft_mz = local_fft_size[ZZ];
3376

    
3377
    /* This routine is called when all threads have finished spreading.
3378
     * Here each thread sums grid contributions calculated by other threads
3379
     * to the part of the grid that is local to this thread.
3380
     * To minimize the number of grid copying operations,
3381
     * this routine sums directly from the pmegrid into the fftgrid.
3382
     */
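
    /* Illustration (1D sketch with hypothetical names, for clarity only):
     * in one dimension, each thread owns the range [my_start,my_end) of
     * fftgrid, while the local spreading grid of another thread f covers
     * [off_f,off_f+n_f) and may extend order-1 points into our range.
     * The reduction below then amounts to, per contributing thread f:
     *
     *     static void reduce_1d_sketch(real *fftgrid, int my_start, int my_end,
     *                                  const real *grid_f, int off_f, int n_f)
     *     {
     *         int i, i0, i1;
     *
     *         i0 = (my_start > off_f       ? my_start : off_f);
     *         i1 = (my_end   < off_f + n_f ? my_end   : off_f + n_f);
     *         for (i = i0; i < i1; i++)
     *         {
     *             fftgrid[i] += grid_f[i - off_f];
     *         }
     *     }
     *
     * The (sx,sy,sz) loops below do exactly this in 3D, except that overlap
     * which wraps around the node boundary is accumulated into commbuf_x or
     * commbuf_y for MPI communication instead of directly into fftgrid.
     */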
3383

    
3384
    /* Determine which part of the full node grid we should operate on;
3385
     * this is our thread-local part of the full grid.
3386
     */
3387
    pmegrid = &pmegrids->grid_th[thread];
3388

    
3389
    for(d=0; d<DIM; d++)
3390
    {
3391
        ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3392
                    local_fft_ndata[d]);
3393
    }
3394

    
3395
    offx = pmegrid->offset[XX];
3396
    offy = pmegrid->offset[YY];
3397
    offz = pmegrid->offset[ZZ];
3398

    
3399

    
3400
    bClearBufX  = TRUE;
3401
    bClearBufY  = TRUE;
3402
    bClearBufXY = TRUE;
3403

    
3404
    /* Now loop over all the thread data blocks that contribute
3405
     * to the grid region we (our thread) are operating on.
3406
     */
3407
    /* Note that fft_nx/fft_ny is equal to the number of grid points
3408
     * between the first point of our node grid and that of the next node.
3409
     */
3410
    for(sx=0; sx>=-pmegrids->nthread_comm[XX]; sx--)
3411
    {
3412
        fx = pmegrid->ci[XX] + sx;
3413
        ox = 0;
3414
        bCommX = FALSE;
3415
        if (fx < 0) {
3416
            fx += pmegrids->nc[XX];
3417
            ox -= fft_nx;
3418
            bCommX = (pme->nnodes_major > 1);
3419
        }
3420
        pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3421
        ox += pmegrid_g->offset[XX];
3422
        if (!bCommX)
3423
        {
3424
            tx1 = min(ox + pmegrid_g->n[XX],ne[XX]);
3425
        }
3426
        else
3427
        {
3428
            tx1 = min(ox + pmegrid_g->n[XX],pme->pme_order);
3429
        }
3430

    
3431
        for(sy=0; sy>=-pmegrids->nthread_comm[YY]; sy--)
3432
        {
3433
            fy = pmegrid->ci[YY] + sy;
3434
            oy = 0;
3435
            bCommY = FALSE;
3436
            if (fy < 0) {
3437
                fy += pmegrids->nc[YY];
3438
                oy -= fft_ny;
3439
                bCommY = (pme->nnodes_minor > 1);
3440
            }
3441
            pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3442
            oy += pmegrid_g->offset[YY];
3443
            if (!bCommY)
3444
            {
3445
                ty1 = min(oy + pmegrid_g->n[YY],ne[YY]);
3446
            }
3447
            else
3448
            {
3449
                ty1 = min(oy + pmegrid_g->n[YY],pme->pme_order);
3450
            }
3451

    
3452
            for(sz=0; sz>=-pmegrids->nthread_comm[ZZ]; sz--)
3453
            {
3454
                fz = pmegrid->ci[ZZ] + sz;
3455
                oz = 0;
3456
                if (fz < 0)
3457
                {
3458
                    fz += pmegrids->nc[ZZ];
3459
                    oz -= fft_nz;
3460
                }
3461
                pmegrid_g = &pmegrids->grid_th[fz];
3462
                oz += pmegrid_g->offset[ZZ];
3463
                tz1 = min(oz + pmegrid_g->n[ZZ],ne[ZZ]);
3464

    
3465
                if (sx == 0 && sy == 0 && sz == 0)
3466
                {
3467
                    /* We have already added our local contribution
3468
                     * before calling this routine, so skip it here.
3469
                     */
3470
                    continue;
3471
                }
3472

    
3473
                thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3474

    
3475
                pmegrid_f = &pmegrids->grid_th[thread_f];
3476

    
3477
                grid_th = pmegrid_f->grid;
3478

    
3479
                nsx = pmegrid_f->n[XX];
3480
                nsy = pmegrid_f->n[YY];
3481
                nsz = pmegrid_f->n[ZZ];
3482

    
3483
#ifdef DEBUG_PME_REDUCE
3484
                printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3485
                       pme->nodeid,thread,thread_f,
3486
                       pme->pmegrid_start_ix,
3487
                       pme->pmegrid_start_iy,
3488
                       pme->pmegrid_start_iz,
3489
                       sx,sy,sz,
3490
                       offx-ox,tx1-ox,offx,tx1,
3491
                       offy-oy,ty1-oy,offy,ty1,
3492
                       offz-oz,tz1-oz,offz,tz1);
3493
#endif
3494

    
3495
                if (!(bCommX || bCommY))
3496
                {
3497
                    /* Copy from the thread local grid to the node grid */
3498
                    for(x=offx; x<tx1; x++)
3499
                    {
3500
                        for(y=offy; y<ty1; y++)
3501
                        {
3502
                            i0  = (x*fft_my + y)*fft_mz;
3503
                            i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3504
                            for(z=offz; z<tz1; z++)
3505
                            {
3506
                                fftgrid[i0+z] += grid_th[i0t+z];
3507
                            }
3508
                        }
3509
                    }
3510
                }
3511
                else
3512
                {
3513
                    /* The order of this conditional decides
3514
                     * where the corner volume gets stored with x+y decomposition.
3515
                     */
3516
                    if (bCommY)
3517
                    {
3518
                        commbuf = commbuf_y;
3519
                        buf_my  = ty1 - offy;
3520
                        if (bCommX)
3521
                        {
3522
                            /* We index commbuf modulo the local grid size */
3523
                            commbuf += buf_my*fft_nx*fft_nz;
3524

    
3525
                            bClearBuf  = bClearBufXY;
3526
                            bClearBufXY = FALSE;
3527
                        }
3528
                        else
3529
                        {
3530
                            bClearBuf  = bClearBufY;
3531
                            bClearBufY = FALSE;
3532
                        }
3533
                    }
3534
                    else
3535
                    {
3536
                        commbuf = commbuf_x;
3537
                        buf_my  = fft_ny;
3538
                        bClearBuf  = bClearBufX;
3539
                        bClearBufX = FALSE;
3540
                    }
3541

    
3542
                    /* Copy to the communication buffer */
3543
                    for(x=offx; x<tx1; x++)
3544
                    {
3545
                        for(y=offy; y<ty1; y++)
3546
                        {
3547
                            i0  = (x*buf_my + y)*fft_nz;
3548
                            i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3549

    
3550
                            if (bClearBuf)
3551
                            {
3552
                                /* First access of commbuf, initialize it */
3553
                                for(z=offz; z<tz1; z++)
3554
                                {
3555
                                    commbuf[i0+z]  = grid_th[i0t+z];
3556
                                }
3557
                            }
3558
                            else
3559
                            {
3560
                                for(z=offz; z<tz1; z++)
3561
                                {
3562
                                    commbuf[i0+z] += grid_th[i0t+z];
3563
                                }
3564
                            }
3565
                        }
3566
                    }
3567
                }
3568
            }
3569
        }
3570
    }
3571
}
3572

    
3573

    
3574
static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
3575
{
3576
    ivec local_fft_ndata,local_fft_offset,local_fft_size;
3577
    pme_overlap_t *overlap;
3578
    int  send_nindex;
3579
    int  recv_index0,recv_nindex;
3580
#ifdef GMX_MPI
3581
    MPI_Status stat;
3582
#endif
3583
    int  ipulse,send_id,recv_id,datasize,gridsize,size_yx;
3584
    real *sendptr,*recvptr;
3585
    int  x,y,z,indg,indb;
3586

    
3587
    /* Note that this routine is only used for forward communication.
3588
     * Since the force gathering, unlike the charge spreading,
3589
     * can be trivially parallelized over the particles,
3590
     * the backwards process is much simpler and can use the "old"
3591
     * communication setup.
3592
     */
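
    /* Illustration (hypothetical helper; only meaningful when compiled with
     * MPI): the overlap reduction below amounts to one MPI_Sendrecv per
     * decomposed dimension followed by an element-wise add of the received
     * block. A minimal 1D version, assuming grid holds nlocal owned points
     * followed by nover halo points that were spread past the local
     * boundary, would be
     *
     *     static void sum_overlap_1d_sketch(real *grid, int nlocal, int nover,
     *                                       int send_id, int recv_id,
     *                                       real *sendbuf, real *recvbuf,
     *                                       MPI_Comm comm)
     *     {
     *         MPI_Status stat;
     *         int i;
     *
     *         for (i = 0; i < nover; i++)
     *         {
     *             sendbuf[i] = grid[nlocal + i];
     *         }
     *         MPI_Sendrecv(sendbuf, nover, GMX_MPI_REAL, send_id, 0,
     *                      recvbuf, nover, GMX_MPI_REAL, recv_id, 0,
     *                      comm, &stat);
     *         for (i = 0; i < nover; i++)
     *         {
     *             grid[i] += recvbuf[i];
     *         }
     *     }
     *
     * The code below sends whole slabs at once and, with 2D decomposition,
     * forwards the received corner volume into the x send buffer, but the
     * communication pattern is the same.
     */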
3593

    
3594
    gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3595
                                   local_fft_ndata,
3596
                                   local_fft_offset,
3597
                                   local_fft_size);
3598

    
3599
    /* Currently supports only a single communication pulse */
3600

    
3601
/* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3602
    if (pme->nnodes_minor > 1)
3603
    {
3604
        /* Minor dimension */
3605
        overlap = &pme->overlap[1];
3606

    
3607
        if (pme->nnodes_major > 1)
3608
        {
3609
             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3610
        }
3611
        else
3612
        {
3613
            size_yx = 0;
3614
        }
3615
        datasize = (local_fft_ndata[XX]+size_yx)*local_fft_ndata[ZZ];
3616

    
3617
        ipulse = 0;
3618

    
3619
        send_id = overlap->send_id[ipulse];
3620
        recv_id = overlap->recv_id[ipulse];
3621
        send_nindex   = overlap->comm_data[ipulse].send_nindex;
3622
        /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3623
        recv_index0 = 0;
3624
        recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3625

    
3626
        sendptr = overlap->sendbuf;
3627
        recvptr = overlap->recvbuf;
3628

    
3629
        /*
3630
        printf("node %d comm %2d x %2d x %2d\n",pme->nodeid,
3631
               local_fft_ndata[XX]+size_yx,send_nindex,local_fft_ndata[ZZ]);
3632
        printf("node %d send %f, %f\n",pme->nodeid,
3633
               sendptr[0],sendptr[send_nindex*datasize-1]);
3634
        */
3635

    
3636
#ifdef GMX_MPI
3637
        MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3638
                     send_id,ipulse,
3639
                     recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3640
                     recv_id,ipulse,
3641
                     overlap->mpi_comm,&stat);
3642
#endif
3643

    
3644
        for(x=0; x<local_fft_ndata[XX]; x++)
3645
        {
3646
            for(y=0; y<recv_nindex; y++)
3647
            {
3648
                indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3649
                indb = (x*recv_nindex        + y)*local_fft_ndata[ZZ];
3650
                for(z=0; z<local_fft_ndata[ZZ]; z++)
3651
                {
3652
                    fftgrid[indg+z] += recvptr[indb+z];
3653
                }
3654
            }
3655
        }
3656
        if (pme->nnodes_major > 1)
3657
        {
3658
            sendptr = pme->overlap[0].sendbuf;
3659
            for(x=0; x<size_yx; x++)
3660
            {
3661
                for(y=0; y<recv_nindex; y++)
3662
                {
3663
                    indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3664
                    indb = ((local_fft_ndata[XX] + x)*recv_nindex +y)*local_fft_ndata[ZZ];
3665
                    for(z=0; z<local_fft_ndata[ZZ]; z++)
3666
                    {
3667
                        sendptr[indg+z] += recvptr[indb+z];
3668
                    }
3669
                }
3670
            }
3671
        }
3672
    }
3673

    
3674
    /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3675
    if (pme->nnodes_major > 1)
3676
    {
3677
        /* Major dimension */
3678
        overlap = &pme->overlap[0];
3679

    
3680
        datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3681
        gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3682

    
3683
        ipulse = 0;
3684

    
3685
        send_id = overlap->send_id[ipulse];
3686
        recv_id = overlap->recv_id[ipulse];
3687
        send_nindex   = overlap->comm_data[ipulse].send_nindex;
3688
        /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3689
        recv_index0 = 0;
3690
        recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3691

    
3692
        sendptr = overlap->sendbuf;
3693
        recvptr = overlap->recvbuf;
3694

    
3695
        if (debug != NULL)
3696
        {
3697
            fprintf(debug,"PME fftgrid comm %2d x %2d x %2d\n",
3698
                   send_nindex,local_fft_ndata[YY],local_fft_ndata[ZZ]);
3699
        }
3700

    
3701
#ifdef GMX_MPI
3702
        MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3703
                     send_id,ipulse,
3704
                     recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3705
                     recv_id,ipulse,
3706
                     overlap->mpi_comm,&stat);
3707
#endif
3708

    
3709
        for(x=0; x<recv_nindex; x++)
3710
        {
3711
            for(y=0; y<local_fft_ndata[YY]; y++)
3712
            {
3713
                indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3714
                indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3715
                for(z=0; z<local_fft_ndata[ZZ]; z++)
3716
                {
3717
                    fftgrid[indg+z] += recvptr[indb+z];
3718
                }
3719
            }
3720
        }
3721
    }
3722
}
3723

    
3724

    
3725
static void spread_on_grid(gmx_pme_t pme,
3726
                           pme_atomcomm_t *atc,pmegrids_t *grids,
3727
                           gmx_bool bCalcSplines,gmx_bool bSpread,
3728
                           real *fftgrid)
3729
{
3730
    int nthread,thread;
3731
#ifdef PME_TIME_THREADS
3732
    gmx_cycles_t c1,c2,c3,ct1a,ct1b,ct1c;
3733
    static double cs1=0,cs2=0,cs3=0;
3734
    static double cs1a[6]={0,0,0,0,0,0};
3735
    static int cnt=0;
3736
#endif
3737

    
3738
    nthread = pme->nthread;
3739
    assert(nthread>0);
3740

    
3741
#ifdef PME_TIME_THREADS
3742
    c1 = omp_cyc_start();
3743
#endif
3744
    if (bCalcSplines)
3745
    {
3746
#pragma omp parallel for num_threads(nthread) schedule(static)
3747
        for(thread=0; thread<nthread; thread++)
3748
        {
3749
            int start,end;
3750

    
3751
            start = atc->n* thread   /nthread;
3752
            end   = atc->n*(thread+1)/nthread;
3753

    
3754
            /* Compute fftgrid index for all atoms,
3755
             * with help of some extra variables.
3756
             */
3757
            calc_interpolation_idx(pme,atc,start,end,thread);
3758
        }
3759
    }
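
    /* Worked example (for clarity only): the start/end computation above
     * splits the atom range evenly without a remainder loop; thread t gets
     *
     *     start = n* t   /nthread;
     *     end   = n*(t+1)/nthread;
     *
     * e.g. n = 10 atoms on nthread = 4 threads gives the ranges
     * [0,2), [2,5), [5,7), [7,10), i.e. chunks of 2, 3, 2 and 3 atoms.
     * Adjacent ranges never overlap or leave gaps, since the end of thread t
     * equals the start of thread t+1.
     */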
3760
#ifdef PME_TIME_THREADS
3761
    c1 = omp_cyc_end(c1);
3762
    cs1 += (double)c1;
3763
#endif
3764

    
3765
#ifdef PME_TIME_THREADS
3766
    c2 = omp_cyc_start();
3767
#endif
3768
#pragma omp parallel for num_threads(nthread) schedule(static)
3769
    for(thread=0; thread<nthread; thread++)
3770
    {
3771
        splinedata_t *spline;
3772
        pmegrid_t *grid;
3773

    
3774
        /* make local bsplines  */
3775
        if (grids == NULL || grids->nthread == 1)
3776
        {
3777
            spline = &atc->spline[0];
3778

    
3779
            spline->n = atc->n;
3780

    
3781
            grid = &grids->grid;
3782
        }
3783
        else
3784
        {
3785
            spline = &atc->spline[thread];
3786

    
3787
            make_thread_local_ind(atc,thread,spline);
3788

    
3789
            grid = &grids->grid_th[thread];
3790
        }
3791

    
3792
        if (bCalcSplines)
3793
        {
3794
            make_bsplines(spline->theta,spline->dtheta,pme->pme_order,
3795
                          atc->fractx,spline->n,spline->ind,atc->q,pme->bFEP);
3796
        }
3797

    
3798
        if (bSpread)
3799
        {
3800
            /* put local atoms on grid. */
3801
#ifdef PME_TIME_SPREAD
3802
            ct1a = omp_cyc_start();
3803
#endif
3804
            spread_q_bsplines_thread(grid,atc,spline,pme->spline_work);
3805

    
3806
            if (grids->nthread > 1)
3807
            {
3808
                copy_local_grid(pme,grids,thread,fftgrid);
3809
            }
3810
#ifdef PME_TIME_SPREAD
3811
            ct1a = omp_cyc_end(ct1a);
3812
            cs1a[thread] += (double)ct1a;
3813
#endif
3814
        }
3815
    }
3816
#ifdef PME_TIME_THREADS
3817
    c2 = omp_cyc_end(c2);
3818
    cs2 += (double)c2;
3819
#endif
3820

    
3821
    if (bSpread && grids->nthread > 1)
3822
    {
3823
#ifdef PME_TIME_THREADS
3824
        c3 = omp_cyc_start();
3825
#endif
3826
#pragma omp parallel for num_threads(grids->nthread) schedule(static)
3827
        for(thread=0; thread<grids->nthread; thread++)
3828
        {
3829
            reduce_threadgrid_overlap(pme,grids,thread,
3830
                                      fftgrid,
3831
                                      pme->overlap[0].sendbuf,
3832
                                      pme->overlap[1].sendbuf);
3833
#ifdef PRINT_PME_SENDBUF
3834
            print_sendbuf(pme,pme->overlap[0].sendbuf);
3835
#endif
3836
        }
3837
#ifdef PME_TIME_THREADS
3838
        c3 = omp_cyc_end(c3);
3839
        cs3 += (double)c3;
3840
#endif
3841

    
3842
        if (pme->nnodes > 1)
3843
        {
3844
            /* Communicate the overlapping part of the fftgrid */
3845
            sum_fftgrid_dd(pme,fftgrid);
3846
        }
3847
    }
3848

    
3849
#ifdef PME_TIME_THREADS
3850
    cnt++;
3851
    if (cnt % 20 == 0)
3852
    {
3853
        printf("idx %.2f spread %.2f red %.2f",
3854
               cs1*1e-9,cs2*1e-9,cs3*1e-9);
3855
#ifdef PME_TIME_SPREAD
3856
        for(thread=0; thread<nthread; thread++)
3857
            printf(" %.2f",cs1a[thread]*1e-9);
3858
#endif
3859
        printf("\n");
3860
    }
3861
#endif
3862
}
3863

    
3864

    
3865
static void dump_grid(FILE *fp,
3866
                      int sx,int sy,int sz,int nx,int ny,int nz,
3867
                      int my,int mz,const real *g)
3868
{
3869
    int x,y,z;
3870

    
3871
    for(x=0; x<nx; x++)
3872
    {
3873
        for(y=0; y<ny; y++)
3874
        {
3875
            for(z=0; z<nz; z++)
3876
            {
3877
                fprintf(fp,"%2d %2d %2d %6.3f\n",
3878
                        sx+x,sy+y,sz+z,g[(x*my + y)*mz + z]);
3879
            }
3880
        }
3881
    }
3882
}
3883

    
3884
static void dump_local_fftgrid(gmx_pme_t pme,const real *fftgrid)
3885
{
3886
    ivec local_fft_ndata,local_fft_offset,local_fft_size;
3887

    
3888
    gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3889
                                   local_fft_ndata,
3890
                                   local_fft_offset,
3891
                                   local_fft_size);
3892

    
3893
    dump_grid(stderr,
3894
              pme->pmegrid_start_ix,
3895
              pme->pmegrid_start_iy,
3896
              pme->pmegrid_start_iz,
3897
              pme->pmegrid_nx-pme->pme_order+1,
3898
              pme->pmegrid_ny-pme->pme_order+1,
3899
              pme->pmegrid_nz-pme->pme_order+1,
3900
              local_fft_size[YY],
3901
              local_fft_size[ZZ],
3902
              fftgrid);
3903
}
3904

    
3905

    
3906
void gmx_pme_calc_energy(gmx_pme_t pme,int n,rvec *x,real *q,real *V)
3907
{
3908
    pme_atomcomm_t *atc;
3909
    pmegrids_t *grid;
3910

    
3911
    if (pme->nnodes > 1)
3912
    {
3913
        gmx_incons("gmx_pme_calc_energy called in parallel");
3914
    }
3915
    if (pme->bFEP > 1)
3916
    {
3917
        gmx_incons("gmx_pme_calc_energy with free energy");
3918
    }
3919

    
3920
    atc = &pme->atc_energy;
3921
    atc->nthread   = 1;
3922
    if (atc->spline == NULL)
3923
    {
3924
        snew(atc->spline,atc->nthread);
3925
    }
3926
    atc->nslab     = 1;
3927
    atc->bSpread   = TRUE;
3928
    atc->pme_order = pme->pme_order;
3929
    atc->n         = n;
3930
    pme_realloc_atomcomm_things(atc);
3931
    atc->x         = x;
3932
    atc->q         = q;
3933

    
3934
    /* We only use the A-charges grid */
3935
    grid = &pme->pmegridA;
3936

    
3937
    spread_on_grid(pme,atc,NULL,TRUE,FALSE,pme->fftgridA);
3938

    
3939
    *V = gather_energy_bsplines(pme,grid->grid.grid,atc);
3940
}
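
/* Usage sketch (hypothetical caller, for illustration only): this routine is
 * the serial entry point for evaluating the mesh energy of a set of charges,
 * e.g. for test-particle insertion. It computes interpolation indices and
 * B-splines for the given positions and then gathers the energy from the
 * current contents of the A grid; it does not spread charges or solve.
 * Assuming pme was set up for a single node and x/q hold n positions and
 * charges, a caller would simply do:
 *
 *     real V;
 *
 *     gmx_pme_calc_energy(pme, n, x, q, &V);
 *
 * As the checks above enforce, it must not be called in parallel or with
 * free-energy (B-state) charges.
 */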
3941

    
3942

    
3943
static void reset_pmeonly_counters(t_commrec *cr,gmx_wallcycle_t wcycle,
3944
        t_nrnb *nrnb,t_inputrec *ir, gmx_large_int_t step_rel)
3945
{
3946
    /* Reset all the counters related to performance over the run */
3947
    wallcycle_stop(wcycle,ewcRUN);
3948
    wallcycle_reset_all(wcycle);
3949
    init_nrnb(nrnb);
3950
    ir->init_step += step_rel;
3951
    ir->nsteps    -= step_rel;
3952
    wallcycle_start(wcycle,ewcRUN);
3953
}
3954

    
3955

    
3956
int gmx_pmeonly(gmx_pme_t pme,
3957
                t_commrec *cr,    t_nrnb *nrnb,
3958
                gmx_wallcycle_t wcycle,
3959
                real ewaldcoeff,  gmx_bool bGatherOnly,
3960
                t_inputrec *ir)
3961
{
3962
    gmx_pme_pp_t pme_pp;
3963
    int  natoms;
3964
    matrix box;
3965
    rvec *x_pp=NULL,*f_pp=NULL;
3966
    real *chargeA=NULL,*chargeB=NULL;
3967
    real lambda=0;
3968
    int  maxshift_x=0,maxshift_y=0;
3969
    real energy,dvdlambda;
3970
    matrix vir;
3971
    float cycles;
3972
    int  count;
3973
    gmx_bool bEnerVir;
3974
    gmx_large_int_t step,step_rel;
3975

    
3976

    
3977
    pme_pp = gmx_pme_pp_init(cr);
3978

    
3979
    init_nrnb(nrnb);
3980

    
3981
    count = 0;
3982
    do /****** this is a quasi-loop over time steps! */
3983
    {
3984
        /* Domain decomposition */
3985
        natoms = gmx_pme_recv_q_x(pme_pp,
3986
                                  &chargeA,&chargeB,box,&x_pp,&f_pp,
3987
                                  &maxshift_x,&maxshift_y,
3988
                                  &pme->bFEP,&lambda,
3989
                                  &bEnerVir,
3990
                                  &step);
3991

    
3992
        if (natoms == -1) {
3993
            /* We should stop: break out of the loop */
3994
            break;
3995
        }
3996

    
3997
        step_rel = step - ir->init_step;
3998

    
3999
        if (count == 0)
4000
            wallcycle_start(wcycle,ewcRUN);
4001

    
4002
        wallcycle_start(wcycle,ewcPMEMESH);
4003

    
4004
        dvdlambda = 0;
4005
        clear_mat(vir);
4006
        gmx_pme_do(pme,0,natoms,x_pp,f_pp,chargeA,chargeB,box,
4007
                   cr,maxshift_x,maxshift_y,nrnb,wcycle,vir,ewaldcoeff,
4008
                   &energy,lambda,&dvdlambda,
4009
                   GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4010

    
4011
        cycles = wallcycle_stop(wcycle,ewcPMEMESH);
4012

    
4013
        gmx_pme_send_force_vir_ener(pme_pp,
4014
                                    f_pp,vir,energy,dvdlambda,
4015
                                    cycles);
4016

    
4017
        count++;
4018

    
4019
        if (step_rel == wcycle_get_reset_counters(wcycle))
4020
        {
4021
            /* Reset all the counters related to performance over the run */
4022
            reset_pmeonly_counters(cr,wcycle,nrnb,ir,step_rel);
4023
            wcycle_set_reset_counters(wcycle, 0);
4024
        }
4025

    
4026
    } /***** end of quasi-loop, we stop with the break above */
4027
    while (TRUE);
4028

    
4029
    return 0;
4030
}
4031

    
4032
int gmx_pme_do(gmx_pme_t pme,
4033
               int start,       int homenr,
4034
               rvec x[],        rvec f[],
4035
               real *chargeA,   real *chargeB,
4036
               matrix box, t_commrec *cr,
4037
               int  maxshift_x, int maxshift_y,
4038
               t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4039
               matrix vir,      real ewaldcoeff,
4040
               real *energy,    real lambda,
4041
               real *dvdlambda, int flags)
4042
{
4043
    int     q,d,i,j,ntot,npme;
4044
    int     nx,ny,nz;
4045
    int     n_d,local_ny;
4046
    pme_atomcomm_t *atc=NULL;
4047
    pmegrids_t *pmegrid=NULL;
4048
    real    *grid=NULL;
4049
    real    *ptr;
4050
    rvec    *x_d,*f_d;
4051
    real    *charge=NULL,*q_d;
4052
    real    energy_AB[2];
4053
    matrix  vir_AB[2];
4054
    gmx_bool bClearF;
4055
    gmx_parallel_3dfft_t pfft_setup;
4056
    real *  fftgrid;
4057
    t_complex * cfftgrid;
4058
    int     thread;
4059
    const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4060
    const gmx_bool bCalcF = flags & GMX_PME_CALC_F;
4061

    
4062
    assert(pme->nnodes > 0);
4063
    assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4064

    
4065
    if (pme->nnodes > 1) {
4066
        atc = &pme->atc[0];
4067
        atc->npd = homenr;
4068
        if (atc->npd > atc->pd_nalloc) {
4069
            atc->pd_nalloc = over_alloc_dd(atc->npd);
4070
            srenew(atc->pd,atc->pd_nalloc);
4071
        }
4072
        atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4073
    }
4074
    else
4075
    {
4076
        /* This could be necessary for TPI */
4077
        pme->atc[0].n = homenr;
4078
    }
4079

    
4080
    for(q=0; q<(pme->bFEP ? 2 : 1); q++) {
4081
        if (q == 0) {
4082
            pmegrid = &pme->pmegridA;
4083
            fftgrid = pme->fftgridA;
4084
            cfftgrid = pme->cfftgridA;
4085
            pfft_setup = pme->pfft_setupA;
4086
            charge = chargeA+start;
4087
        } else {
4088
            pmegrid = &pme->pmegridB;
4089
            fftgrid = pme->fftgridB;
4090
            cfftgrid = pme->cfftgridB;
4091
            pfft_setup = pme->pfft_setupB;
4092
            charge = chargeB+start;
4093
        }
4094
        grid = pmegrid->grid.grid;
4095
        /* Unpack structure */
4096
        if (debug) {
4097
            fprintf(debug,"PME: nnodes = %d, nodeid = %d\n",
4098
                    cr->nnodes,cr->nodeid);
4099
            fprintf(debug,"Grid = %p\n",(void*)grid);
4100
            if (grid == NULL)
4101
                gmx_fatal(FARGS,"No grid!");
4102
        }
4103
        where();
4104

    
4105
        m_inv_ur0(box,pme->recipbox);
4106

    
4107
        if (pme->nnodes == 1) {
4108
            atc = &pme->atc[0];
4109
            if (DOMAINDECOMP(cr)) {
4110
                atc->n = homenr;
4111
                pme_realloc_atomcomm_things(atc);
4112
            }
4113
            atc->x = x;
4114
            atc->q = charge;
4115
            atc->f = f;
4116
        } else {
4117
            wallcycle_start(wcycle,ewcPME_REDISTXF);
4118
            for(d=pme->ndecompdim-1; d>=0; d--)
4119
            {
4120
                if (d == pme->ndecompdim-1)
4121
                {
4122
                    n_d = homenr;
4123
                    x_d = x + start;
4124
                    q_d = charge;
4125
                }
4126
                else
4127
                {
4128
                    n_d = pme->atc[d+1].n;
4129
                    x_d = atc->x;
4130
                    q_d = atc->q;
4131
                }
4132
                atc = &pme->atc[d];
4133
                atc->npd = n_d;
4134
                if (atc->npd > atc->pd_nalloc) {
4135
                    atc->pd_nalloc = over_alloc_dd(atc->npd);
4136
                    srenew(atc->pd,atc->pd_nalloc);
4137
                }
4138
                atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4139
                pme_calc_pidx_wrapper(n_d,pme->recipbox,x_d,atc);
4140
                where();
4141

    
4142
                GMX_BARRIER(cr->mpi_comm_mygroup);
4143
                /* Redistribute x (only once) and qA or qB */
4144
                if (DOMAINDECOMP(cr)) {
4145
                    dd_pmeredist_x_q(pme, n_d, q==0, x_d, q_d, atc);
4146
                } else {
4147
                    pmeredist_pd(pme, TRUE, n_d, q==0, x_d, q_d, atc);
4148
                }
4149
            }
4150
            where();
4151

    
4152
            wallcycle_stop(wcycle,ewcPME_REDISTXF);
4153
        }
4154

    
4155
        if (debug)
4156
            fprintf(debug,"Node= %6d, pme local particles=%6d\n",
4157
                    cr->nodeid,atc->n);
4158

    
4159
        if (flags & GMX_PME_SPREAD_Q)
4160
        {
4161
            wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4162

    
4163
            /* Spread the charges on a grid */
4164
            GMX_MPE_LOG(ev_spread_on_grid_start);
4165

    
4166
            /* Spread the charges on a grid */
4167
            spread_on_grid(pme,&pme->atc[0],pmegrid,q==0,TRUE,fftgrid);
4168
            GMX_MPE_LOG(ev_spread_on_grid_finish);
4169

    
4170
            if (q == 0)
4171
            {
4172
                inc_nrnb(nrnb,eNR_WEIGHTS,DIM*atc->n);
4173
            }
4174
            inc_nrnb(nrnb,eNR_SPREADQBSP,
4175
                     pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4176

    
4177
            if (pme->nthread == 1)
4178
            {
4179
                wrap_periodic_pmegrid(pme,grid);
4180

    
4181
                /* sum contributions to local grid from other nodes */
4182
#ifdef GMX_MPI
4183
                if (pme->nnodes > 1)
4184
                {
4185
                    GMX_BARRIER(cr->mpi_comm_mygroup);
4186
                    gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_FORWARD);
4187
                    where();
4188
                }
4189
#endif
4190

    
4191
                copy_pmegrid_to_fftgrid(pme,grid,fftgrid);
4192
            }
4193

    
4194
            wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4195

    
4196
            /*
4197
            dump_local_fftgrid(pme,fftgrid);
4198
            exit(0);
4199
            */
4200
        }
4201

    
4202
        /* Here we start a large thread parallel region */
4203
#pragma omp parallel for num_threads(pme->nthread) schedule(static)
4204
        for(thread=0; thread<pme->nthread; thread++)
4205
        {
4206
            if (flags & GMX_PME_SOLVE)
4207
            {
4208
                int loop_count;
4209

    
4210
                /* do 3d-fft */
4211
                if (thread == 0)
4212
                {
4213
                    GMX_BARRIER(cr->mpi_comm_mygroup);
4214
                    GMX_MPE_LOG(ev_gmxfft3d_start);
4215
                    wallcycle_start(wcycle,ewcPME_FFT);
4216
                }
4217
                gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_REAL_TO_COMPLEX,
4218
                                           fftgrid,cfftgrid,thread,wcycle);
4219
                if (thread == 0)
4220
                {
4221
                    wallcycle_stop(wcycle,ewcPME_FFT);
4222
                    GMX_MPE_LOG(ev_gmxfft3d_finish);
4223
                }
4224
                where();
4225

    
4226
                /* solve in k-space for our local cells */
4227
                if (thread == 0)
4228
                {
4229
                    GMX_BARRIER(cr->mpi_comm_mygroup);
4230
                    GMX_MPE_LOG(ev_solve_pme_start);
4231
                    wallcycle_start(wcycle,ewcPME_SOLVE);
4232
                }
4233
                loop_count =
4234
                    solve_pme_yzx(pme,cfftgrid,ewaldcoeff,
4235
                                  box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4236
                                  bCalcEnerVir,
4237
                                  pme->nthread,thread);
4238
                if (thread == 0)
4239
                {
4240
                    wallcycle_stop(wcycle,ewcPME_SOLVE);
4241
                    where();
4242
                    GMX_MPE_LOG(ev_solve_pme_finish);
4243
                    inc_nrnb(nrnb,eNR_SOLVEPME,loop_count);
4244
                }
4245
            }
4246

    
4247
            if (bCalcF)
4248
            {
4249
                /* do 3d-invfft */
4250
                if (thread == 0)
4251
                {
4252
                    GMX_BARRIER(cr->mpi_comm_mygroup);
4253
                    GMX_MPE_LOG(ev_gmxfft3d_start);
4254
                    where();
4255
                    wallcycle_start(wcycle,ewcPME_FFT);
4256
                }
4257
                gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_COMPLEX_TO_REAL,
4258
                                           cfftgrid,fftgrid,thread,wcycle);
4259
                if (thread == 0)
4260
                {
4261
                    wallcycle_stop(wcycle,ewcPME_FFT);
4262

    
4263
                    where();
4264
                    GMX_MPE_LOG(ev_gmxfft3d_finish);
4265

    
4266
                    if (pme->nodeid == 0)
4267
                    {
4268
                        ntot = pme->nkx*pme->nky*pme->nkz;
4269
                        npme  = ntot*log((real)ntot)/log(2.0);
4270
                        inc_nrnb(nrnb,eNR_FFT,2*npme);
4271
                    }
4272

    
4273
                    wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4274
                }
4275

    
4276
                copy_fftgrid_to_pmegrid(pme,fftgrid,grid,pme->nthread,thread);
4277
            }
4278
        }
4279
        /* End of thread parallel section.
4280
         * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4281
         */
4282

    
4283
        if (bCalcF)
4284
        {
4285
            /* distribute local grid to all nodes */
4286
#ifdef GMX_MPI
4287
            if (pme->nnodes > 1) {
4288
                GMX_BARRIER(cr->mpi_comm_mygroup);
4289
                gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_BACKWARD);
4290
            }
4291
#endif
4292
            where();
4293

    
4294
            unwrap_periodic_pmegrid(pme,grid);
4295

    
4296
            /* interpolate forces for our local atoms */
4297
            GMX_BARRIER(cr->mpi_comm_mygroup);
4298
            GMX_MPE_LOG(ev_gather_f_bsplines_start);
4299

    
4300
            where();
4301

    
4302
            /* If we are running without parallelization,
4302
             * atc->f is the actual force array, not a buffer;
4303
             * therefore we should not clear it.
4304
             */
4306
            bClearF = (q == 0 && PAR(cr));
4307
#pragma omp parallel for num_threads(pme->nthread) schedule(static)
4308
            for(thread=0; thread<pme->nthread; thread++)
4309
            {
4310
                gather_f_bsplines(pme,grid,bClearF,atc,
4311
                                  &atc->spline[thread],
4312
                                  pme->bFEP ? (q==0 ? 1.0-lambda : lambda) : 1.0);
4313
            }
4314

    
4315
            where();
4316

    
4317
            GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4318

    
4319
            inc_nrnb(nrnb,eNR_GATHERFBSP,
4320
                     pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4321
            wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4322
        }
4323

    
4324
        if (bCalcEnerVir)
4325
        {
4326
            /* This should only be called on the master thread
4327
             * and after the threads have synchronized.
4328
             */
4329
            get_pme_ener_vir(pme,pme->nthread,&energy_AB[q],vir_AB[q]);
4330
        }
4331
    } /* of q-loop */
4332

    
4333
    if (bCalcF && pme->nnodes > 1) {
4334
        wallcycle_start(wcycle,ewcPME_REDISTXF);
4335
        for(d=0; d<pme->ndecompdim; d++)
4336
        {
4337
            atc = &pme->atc[d];
4338
            if (d == pme->ndecompdim - 1)
4339
            {
4340
                n_d = homenr;
4341
                f_d = f + start;
4342
            }
4343
            else
4344
            {
4345
                n_d = pme->atc[d+1].n;
4346
                f_d = pme->atc[d+1].f;
4347
            }
4348
            GMX_BARRIER(cr->mpi_comm_mygroup);
4349
            if (DOMAINDECOMP(cr)) {
4350
                dd_pmeredist_f(pme,atc,n_d,f_d,
4351
                               d==pme->ndecompdim-1 && pme->bPPnode);
4352
            } else {
4353
                pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4354
            }
4355
        }
4356

    
4357
        wallcycle_stop(wcycle,ewcPME_REDISTXF);
4358
    }
4359
    where();
4360

    
4361
    if (bCalcEnerVir)
4362
    {
4363
        if (!pme->bFEP) {
4364
            *energy = energy_AB[0];
4365
            m_add(vir,vir_AB[0],vir);
4366
        } else {
4367
            *energy = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4368
            *dvdlambda += energy_AB[1] - energy_AB[0];
4369
            for(i=0; i<DIM; i++)
4370
            {
4371
                for(j=0; j<DIM; j++)
4372
                {
4373
                    vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] + 
4374
                        lambda*vir_AB[1][i][j];
4375
                }
4376
            }
4377
        }
4378
    }
4379
    else
4380
    {
4381
        *energy = 0;
4382
    }
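
    /* Clarification of the free-energy branch above: with linear coupling the
     * A- and B-state mesh contributions are mixed as
     *
     *     E(lambda)   = (1 - lambda)*E_A   + lambda*E_B
     *     dE/dlambda  = E_B - E_A
     *     vir(lambda) = (1 - lambda)*vir_A + lambda*vir_B
     *
     * which is what the energy_AB/vir_AB combination above computes; the
     * forces were already mixed with the same (1-lambda)/lambda weights in
     * gather_f_bsplines().
     */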
4383

    
4384
    if (debug)
4385
    {
4386
        fprintf(debug,"PME mesh energy: %g\n",*energy);
4387
    }
4388

    
4389
    return 0;
4390
}