Project

General

Profile

« Previous | Next » 

Revision 62b9f665

Added by Leszek Koltunski 12 months ago

speedup

View differences:

src/main/java/org/distorted/objectlib/tablebases/TBCuboid323.java
10 10
package org.distorted.objectlib.tablebases;
11 11

  
12 12
import org.distorted.library.type.Static3D;
13
import org.distorted.library.type.Static4D;
14 13
import org.distorted.objectlib.R;
15 14
import org.distorted.objectlib.helpers.OperatingSystemInterface;
16 15

  
......
202 201
// purely for speedup
203 202

  
204 203
  @Override
205
  int computeRow(float[] pos, int quat, int axisIndex)
204
  void computeRow(float[] pos, int quat, int[] output)
206 205
    {
207
    Static4D q = mQuats[quat];
208
    float qx = q.get0();
209
    float qy = q.get1();
210
    float qz = q.get2();
211
    float qw = q.get3();
206
    float[] q = mQuats[quat];
207
    float qx = q[0];
208
    float qy = q[1];
209
    float qz = q[2];
210
    float qw = q[3];
212 211
    float rx = pos[0];
213 212
    float ry = pos[1];
214 213
    float rz = pos[2];
......
218 217
    mTmp[2] = qz + rz*qw - ry*qx + rx*qy;
219 218
    mTmp[3] = qw - rz*qz - ry*qy - rx*qx;
220 219

  
221
    switch(axisIndex)
222
      {
223
      case 0: float res0 = qw*mTmp[0] + qz*mTmp[1] - qy*mTmp[2] - qx*mTmp[3];
224
              if( res0<-0.5f ) return 1;
225
              if( res0< 0.5f ) return 2;
226
              return 4;
227
      case 1: float res1 = qw*mTmp[1] - qz*mTmp[0] - qy*mTmp[3] + qx*mTmp[2];
228
              return (res1<0 ? 1:2);
229
      case 2: float res2 = qw*mTmp[2] - qz*mTmp[3] + qy*mTmp[0] - qx*mTmp[1];
230
              if( res2<-0.5f ) return 1;
231
              if( res2< 0.5f ) return 2;
232
              return 4;
233
      }
220
    float x = qw*mTmp[0] + qz*mTmp[1] - qy*mTmp[2] - qx*mTmp[3];
221
    float y = qw*mTmp[1] - qz*mTmp[0] - qy*mTmp[3] + qx*mTmp[2];
222
    float z = qw*mTmp[2] - qz*mTmp[3] + qy*mTmp[0] - qx*mTmp[1];
234 223

  
235
    return 0;
224
    output[0] = (x<-0.5f ? 1 : (x<0.5f ? 2:4));
225
    output[1] = (y<0 ? 1:2);
226
    output[2] = (z<-0.5f ? 1 : (z<0.5f ? 2:4));
236 227
    }
237 228

  
238 229
///////////////////////////////////////////////////////////////////////////////////////////////////
src/main/java/org/distorted/objectlib/tablebases/TablebasesAbstract.java
39 39
  Tablebase mTablebase;
40 40
  boolean mInitialized;
41 41

  
42
  final Static4D[] mQuats;
42
  final float[][] mQuats;
43 43
  final int mScalingFactor;
44 44
  final int mNumAxis;
45 45
  final float[][] mPosition;
......
74 74
    mNumAxis = mAxis.length;
75 75
    mNumLayers = new int[mNumAxis];
76 76
    for(int i=0; i<mNumAxis; i++) mNumLayers[i] = mAngles[i].length;
77
    mQuats = QuatGroupGenerator.computeGroup(mAxis,mAngles);
78
    mNumQuats = mQuats.length;
77

  
78
    Static4D[] quats = QuatGroupGenerator.computeGroup(mAxis,mAngles);
79
    mNumQuats = quats.length;
80
    mQuats = new float[mNumQuats][];
81

  
82
    for(int i=0; i<mNumQuats; i++)
83
      {
84
      Static4D q = quats[i];
85
      mQuats[i] = new float[] {q.get0(),q.get1(),q.get2(),q.get3()};
86
      }
87

  
79 88
    mPosition = getPosition();
80 89
    mNumCubits = mPosition.length;
81 90
    mRotatable = getRotatable();
......
135 144

  
136 145
///////////////////////////////////////////////////////////////////////////////////////////////////
137 146

  
138
  int computeRow(float[] pos, int quat, int axisIndex)
147
  void computeRow(float[] pos, int quat, int[] output)
139 148
    {
140
    int ret=0;
141 149
    int len = pos.length/3;
142
    Static3D axis = mAxis[axisIndex];
143
    float axisX = axis.get0();
144
    float axisY = axis.get1();
145
    float axisZ = axis.get2();
146
    float casted;
147
    Static4D q = mQuats[quat];
150
    float[] q = mQuats[quat];
151
    int num = output.length;
148 152

  
149 153
    for(int i=0; i<len; i++)
150 154
      {
151 155
      QuatHelper.rotateVectorByQuat(mTmp,pos[3*i],pos[3*i+1],pos[3*i+2],1.0f,q);
152
      casted = mTmp[0]*axisX + mTmp[1]*axisY + mTmp[2]*axisZ;
153
      ret |= computeSingleRow(axisIndex,casted);
154
      }
155 156

  
156
    return ret;
157
      for(int j=0; j<num; j++)
158
        {
159
        output[j] = 0;
160
        Static3D axis = mAxis[j];
161
        float axisX = axis.get0();
162
        float axisY = axis.get1();
163
        float axisZ = axis.get2();
164
        float casted = mTmp[0]*axisX + mTmp[1]*axisY + mTmp[2]*axisZ;
165
        output[j] |= computeSingleRow(j,casted);
166
        }
167
      }
157 168
    }
158 169

  
159 170
///////////////////////////////////////////////////////////////////////////////////////////////////
......
173 184
///////////////////////////////////////////////////////////////////////////////////////////////////
174 185
// remember about the double cover or unit quaternions!
175 186

  
176
  public static int mulQuat(int q1, int q2, Static4D[] quats)
187
  public static int mulQuat(int q1, int q2, float[][] quats)
177 188
    {
178 189
    int numQuats = quats.length;
179
    Static4D result = QuatHelper.quatMultiply(quats[q1],quats[q2]);
190
    float[] result = QuatHelper.quatMultiply(quats[q1],quats[q2]);
180 191

  
181
    float rX = result.get0();
182
    float rY = result.get1();
183
    float rZ = result.get2();
184
    float rW = result.get3();
192
    float rX = result[0];
193
    float rY = result[1];
194
    float rZ = result[2];
195
    float rW = result[3];
185 196

  
186 197
    final float MAX_ERROR = 0.1f;
187 198
    float dX,dY,dZ,dW;
188 199

  
189 200
    for(int i=0; i<numQuats; i++)
190 201
      {
191
      dX = quats[i].get0() - rX;
192
      dY = quats[i].get1() - rY;
193
      dZ = quats[i].get2() - rZ;
194
      dW = quats[i].get3() - rW;
202
      float[] q = quats[i];
203
      dX = q[0] - rX;
204
      dY = q[1] - rY;
205
      dZ = q[2] - rZ;
206
      dW = q[3] - rW;
195 207

  
196 208
      if( dX<MAX_ERROR && dX>-MAX_ERROR &&
197 209
          dY<MAX_ERROR && dY>-MAX_ERROR &&
198 210
          dZ<MAX_ERROR && dZ>-MAX_ERROR &&
199 211
          dW<MAX_ERROR && dW>-MAX_ERROR  ) return i;
200 212

  
201
      dX = quats[i].get0() + rX;
202
      dY = quats[i].get1() + rY;
203
      dZ = quats[i].get2() + rZ;
204
      dW = quats[i].get3() + rW;
213
      dX = q[0] + rX;
214
      dY = q[1] + rY;
215
      dZ = q[2] + rZ;
216
      dW = q[3] + rW;
205 217

  
206 218
      if( dX<MAX_ERROR && dX>-MAX_ERROR &&
207 219
          dY<MAX_ERROR && dY>-MAX_ERROR &&
......
269 281
    byte newLevel = (byte)(level+1);
270 282
    int quatBasis = 0;
271 283

  
272
    for(int ax=0; ax<mNumAxis; ax++)
273
      for(int cubit=0; cubit<mNumCubits; cubit++)
274
        mRotRow[cubit][ax] = computeRow(mPosition[cubit],quats[cubit],ax);
284
    for(int cubit=0; cubit<mNumCubits; cubit++) computeRow(mPosition[cubit],quats[cubit],mRotRow[cubit]);
275 285

  
276 286
    for(int ax=0; ax<mNumAxis; ax++)
277 287
      {
......
472 482
      data[2]=1;
473 483
      data[3]=1;
474 484

  
475
      for(int ax=0; ax<mNumAxis; ax++)
476
        for(int cubit=0; cubit<mNumCubits; cubit++)
477
          mRotRow[cubit][ax] = computeRow(mPosition[cubit],quats[cubit],ax);
485
      for(int cubit=0; cubit<mNumCubits; cubit++) computeRow(mPosition[cubit],quats[cubit],mRotRow[cubit]);
478 486

  
479 487
      for(int s=0; s<mScalingFactor && !found; s++)
480 488
        {
......
578 586
    data[2]=1;
579 587
    data[3]=1;
580 588

  
581
    for(int ax=0; ax<mNumAxis; ax++)
582
      for(int cubit=0; cubit<mNumCubits; cubit++)
583
        mRotRow[cubit][ax] = computeRow(mPosition[cubit],quats[cubit],ax);
589
    for(int cubit=0; cubit<mNumCubits; cubit++) computeRow(mPosition[cubit],quats[cubit],mRotRow[cubit]);
584 590

  
585 591
    for(int s=0; s<mScalingFactor; s++)
586 592
      {
src/main/java/org/distorted/objectlib/tablebases/TablebasesPruning.java
174 174
    move[2]=1;
175 175
    move[3]=1;
176 176

  
177
    for(int ax=0; ax<mNumAxis; ax++)
178
      for(int cubit=0; cubit<mNumCubits; cubit++)
179
        mRotRow[cubit][ax] = computeRow(mPosition[cubit],quats[cubit],ax);
177
    for(int cubit=0; cubit<mNumCubits; cubit++) computeRow(mPosition[cubit],quats[cubit],mRotRow[cubit]);
180 178

  
181 179
    for(int s=0; s<mScalingFactor; s++)
182 180
      {
......
287 285
    move[2]=1;
288 286
    move[3]=1;
289 287

  
290
    for(int ax=0; ax<mNumAxis; ax++)
291
      for(int cubit=0; cubit<mNumCubits; cubit++)
292
        rotRow[cubit][ax] = computeRow(mPosition[cubit],quats[cubit],ax);
288
    for(int cubit=0; cubit<mNumCubits; cubit++) computeRow(mPosition[cubit],quats[cubit],rotRow[cubit]);
293 289

  
294 290
    for(int s=0; s<mScalingFactor; s++)
295 291
      {
......
377 373
    move[2]=1;
378 374
    move[3]=1;
379 375

  
380
    for(int ax=0; ax<mNumAxis; ax++)
381
      for(int cubit=0; cubit<mNumCubits; cubit++)
382
        rotRow[cubit][ax] = computeRow(mPosition[cubit],quats[cubit],ax);
376
    for(int cubit=0; cubit<mNumCubits; cubit++) computeRow(mPosition[cubit],quats[cubit],rotRow[cubit]);
383 377

  
384 378
    for(int s=0; s<mScalingFactor; s++)
385 379
      {

Also available in: Unified diff