Project

General

Profile

Download (9.87 KB) Statistics
| Branch: | Revision:

distorted-objectlib / src / main / java / org / distorted / objectlib / tablebases / PruningTable.java @ 15d1f6ad

1
///////////////////////////////////////////////////////////////////////////////////////////////////
2
// Copyright 2023 Leszek Koltunski                                                               //
3
//                                                                                               //
4
// This file is part of Magic Cube.                                                              //
5
//                                                                                               //
6
// Magic Cube is proprietary software licensed under an EULA which you should have received      //
7
// along with the code. If not, check https://distorted.org/magic/License-Magic-Cube.html        //
8
///////////////////////////////////////////////////////////////////////////////////////////////////
9

    
10
package org.distorted.objectlib.tablebases;
11

    
12
///////////////////////////////////////////////////////////////////////////////////////////////////
13

    
14
class PruningTable
15
{
16
  private static final int HEADER_SIZE = 5;
17
  private byte[] mData;
18
  private int mNumBits;
19
  private int mBucketBytes;
20
  private int mNumBuckets;
21
  private int mLevel;
22

    
23
///////////////////////////////////////////////////////////////////////////////////////////////////
24

    
25
  private static String printByte(byte b)
26
    {
27
    return String.format("%8s", Integer.toBinaryString(b & 0xFF)).replace(' ', '0');
28
    }
29

    
30
///////////////////////////////////////////////////////////////////////////////////////////////////
31

    
32
  private int numBytesForIndices(int total)
33
    {
34
    int bytes = 0;
35

    
36
    while(total!=0)
37
      {
38
      total>>=8;
39
      bytes++;
40
      }
41

    
42
    return bytes;
43
    }
44

    
45
///////////////////////////////////////////////////////////////////////////////////////////////////
46

    
47
  private void writeHeader()
48
    {
49
    mData[0] = (byte)mLevel;
50
    mData[1] = (byte)mBucketBytes;
51
    mData[2] = (byte)mNumBits;
52
    mData[3] = (byte)(mNumBuckets/256);
53
    mData[4] = (byte)(mNumBuckets%256);
54
    }
55

    
56
///////////////////////////////////////////////////////////////////////////////////////////////////
57

    
58
  private int getBucketOffset(int bucket)
59
    {
60
    if( bucket>=mNumBuckets ) return mData.length;
61

    
62
    int ret = 0;
63
    int start = HEADER_SIZE + bucket*mBucketBytes;
64

    
65
    for(int i=0; i<mBucketBytes; i++)
66
      {
67
      ret<<=8;
68
      ret += ( ((int)(mData[start+i]))&0xff);
69
      }
70

    
71
    return HEADER_SIZE+mNumBuckets*mBucketBytes+ret;
72
    }
73

    
74
///////////////////////////////////////////////////////////////////////////////////////////////////
75

    
76
  private void writeBucketPointers(int[] bucket)
77
    {
78
    int start=HEADER_SIZE;
79

    
80
    for( int b : bucket )
81
      {
82
      for( int i=0; i<mBucketBytes; i++)
83
        {
84
        mData[start+mBucketBytes-1-i] = (byte)(b%256);
85
        b>>=8;
86
        }
87
      start += mBucketBytes;
88
      }
89
    }
90

    
91
///////////////////////////////////////////////////////////////////////////////////////////////////
92

    
93
  private int retrieveData(int index, int entryInBucket)
94
    {
95
    int bits = entryInBucket*mNumBits;
96
    int bytes = index + bits/8;
97
    int ret=0;
98

    
99
    switch(mNumBits)
100
      {
101
      case  4: ret = (bits%8)!=0 ? (mData[bytes]&0xf) : ((mData[bytes]>>>4)&0xf);
102
               break;
103
      case  8: ret = (mData[bytes]&0xff);
104
               break;
105
      case 12: ret = (bits%8)!=0 ? ((mData[bytes]&0xf)*256 + (mData[bytes+1]&0xff)) : ( (mData[bytes]&0xff)*16 + ((mData[bytes+1]>>>4)&0xf));
106
               break;
107
      case 16: ret = (mData[bytes]&0xff)*256 + (mData[bytes+1]&0xff);
108
               break;
109
      }
110

    
111
    return ret;
112
    }
113

    
114
///////////////////////////////////////////////////////////////////////////////////////////////////
115

    
116
  private void insertData(int index, int entryInBucket, int value)
117
    {
118
    int bits = entryInBucket*mNumBits;
119
    int bytes = index + bits/8;
120

    
121
    switch(mNumBits)
122
      {
123
      case  4: if( (bits%8) !=0 ) mData[bytes] = (byte)((mData[bytes]&(~0xf)) + value  );
124
               else               mData[bytes] = (byte)((value<<4) + (mData[bytes]&0xf));
125
               break;
126
      case  8: mData[bytes] = (byte)value;
127
               break;
128
      case 12: if( (bits%8) !=0 )
129
                 {
130
                 mData[bytes]  = (byte)((mData[bytes]&(~0xf)) + (value>>>8));
131
                 mData[bytes+1]= (byte)value;
132
                 }
133
               else
134
                 {
135
                 mData[bytes]  = (byte)(value>>>4);
136
                 mData[bytes+1]= (byte)((byte)(value<<4) + (mData[bytes+1]&( 0xf)));
137
                 }
138
               break;
139
      case 16: mData[bytes]   = (byte)(value>>>8);
140
               mData[bytes+1] = (byte)(value&(0xff));
141
      }
142
    }
143

    
144
///////////////////////////////////////////////////////////////////////////////////////////////////
145

    
146
  private void writeBuckets(Tablebase table)
147
    {
148
    int mask = (1<<mNumBits);
149
    int prevBucket = -1;
150
    int entryInBucket = 0;
151
    int index = 0;
152
    int size = table.getSize();
153

    
154
    for(int i=0; i<size; i++)
155
      if( table.retrieveUnpacked(i)==mLevel )
156
        {
157
        int toBeWritten = (i%mask);
158
        int currBucket = (i>>mNumBits);
159

    
160
        if( currBucket!=prevBucket )
161
          {
162
          index = getBucketOffset(currBucket);
163
          entryInBucket = 0;
164
          }
165
        else
166
          {
167
          entryInBucket ++;
168
          }
169

    
170
        prevBucket = currBucket;
171
        insertData(index,entryInBucket,toBeWritten);
172
        }
173
    }
174

    
175
///////////////////////////////////////////////////////////////////////////////////////////////////
176

    
177
  private int retNumPositions(Tablebase table, int level)
178
    {
179
    int size=0, tableSize = table.getSize();
180

    
181
    for(int i=0; i<tableSize; i++)
182
      if( table.retrieveUnpacked(i)==level ) size++;
183

    
184
    return size;
185
    }
186

    
187
///////////////////////////////////////////////////////////////////////////////////////////////////
188

    
189
  private int computeApproximateSize(int tableSize, int positionSize, int numBits)
190
    {
191
    int entrySize = (1<<numBits);
192
    int numBuckets = (tableSize+entrySize-1)/entrySize;
193
    int overhang = (numBits==4 || numBits==12) ? numBuckets/4 : 0;
194
    int totalBytesForTable = (positionSize*numBits)/8 + overhang;
195
    int bucketBytes = numBytesForIndices(totalBytesForTable);
196
    int totalBytesForBuckets = numBuckets*bucketBytes;
197

    
198
    return HEADER_SIZE + totalBytesForBuckets + totalBytesForTable;
199
    }
200

    
201
///////////////////////////////////////////////////////////////////////////////////////////////////
202

    
203
  private void construct(Tablebase table, int level, int numBits)
204
    {
205
    mNumBits = numBits;
206
    mLevel = level;
207
    int size = table.getSize();
208
    int entrySize = (1<<numBits);
209
    mNumBuckets = (size+entrySize-1)/entrySize;
210
    int[] bucket = new int[mNumBuckets];
211

    
212
    for(int i=0; i<size; i++)
213
      if( table.retrieveUnpacked(i)==level )
214
        {
215
        int currBucket = i/entrySize;
216
        bucket[currBucket]++;
217
        }
218

    
219
    int totalBytesForTable = 0;
220

    
221
    for(int i=0; i<mNumBuckets; i++)
222
      {
223
      int numBytes = (bucket[i]*numBits+7)/8;
224
      bucket[i] = totalBytesForTable;
225
      totalBytesForTable+=numBytes;
226
      }
227

    
228
    mBucketBytes = numBytesForIndices(totalBytesForTable);
229
    int totalBytesForBuckets = mNumBuckets*mBucketBytes;
230
    int totalSize = HEADER_SIZE + totalBytesForBuckets + totalBytesForTable;
231
    mData = new byte[totalSize];
232

    
233
    android.util.Log.e("D", "Constructing a pruningTable, numBits="+numBits+" size "+totalSize);
234

    
235
    writeHeader();
236
    writeBucketPointers(bucket);
237
    writeBuckets(table);
238
    }
239

    
240
///////////////////////////////////////////////////////////////////////////////////////////////////
241
// PUBLIC API
242
///////////////////////////////////////////////////////////////////////////////////////////////////
243
// only numBits = 4,8,12,16 actually supported
244

    
245
  PruningTable(Tablebase table, int level, int numBits)
246
    {
247
    construct(table,level,numBits);
248
    }
249

    
250
///////////////////////////////////////////////////////////////////////////////////////////////////
251

    
252
  PruningTable(Tablebase table, int level)
253
    {
254
    int[] supported = {4,8,12,16};
255
    int minimum = 0, chosen =  0;
256
    int size = retNumPositions(table,level);
257
    int tableSize = table.getSize();
258

    
259
    android.util.Log.e("D", "----- PRUNING TABLE LEVEL "+level+" ------");
260

    
261
    for(int i=0; i<supported.length; i++)
262
      {
263
      int approx = computeApproximateSize( tableSize, size, supported[i]);
264

    
265
      if( i==0 || approx<minimum )
266
        {
267
        minimum = approx;
268
        chosen  = i;
269
        }
270

    
271
      android.util.Log.e("D", "trying numBits: "+supported[i]+" approx size: "+approx);
272
      }
273

    
274
    construct(table,level,supported[chosen]);
275
    }
276

    
277
///////////////////////////////////////////////////////////////////////////////////////////////////
278

    
279
  PruningTable(byte[] data)
280
    {
281
    mData = data;
282

    
283
    mLevel       = ((int)mData[0])&0xff;
284
    mBucketBytes = ((int)mData[1])&0xff;
285
    mNumBits     = ((int)mData[2])&0xff;
286
    mNumBuckets  =(((int)mData[3])&0xff)*256 + ((int)mData[4])&0xff;
287
    }
288

    
289
///////////////////////////////////////////////////////////////////////////////////////////////////
290

    
291
  byte[] getPacked()
292
    {
293
    return mData;
294
    }
295

    
296
///////////////////////////////////////////////////////////////////////////////////////////////////
297

    
298
  int getLevel()
299
    {
300
    return mLevel;
301
    }
302

    
303
///////////////////////////////////////////////////////////////////////////////////////////////////
304

    
305
  boolean belongs(int number)
306
    {
307
    int bucket = (number>>mNumBits);
308
    int offset1 = getBucketOffset(bucket);
309
    int offset2 = getBucketOffset(bucket+1);
310
    int numBytesInBucket = offset2-offset1;
311
    int numEntriesInBucket = (8*numBytesInBucket)/mNumBits;
312

    
313
    if( mNumBits==4 && numBytesInBucket>0 )
314
      {
315
      byte val = mData[offset2-1];
316
      if( (val/16)*16 == val ) numEntriesInBucket--;
317
      }
318

    
319
    int lower = 0;
320
    int upper = numEntriesInBucket-1;
321

    
322
    while( upper>=lower )
323
      {
324
      int mid = (lower+upper)/2;
325
      int value = retrieveData(offset1,mid);
326

    
327
      value += (bucket<<mNumBits);
328
      if( value==number ) return true;
329

    
330
      if( value>number ) upper = mid-1;
331
      else               lower = mid+1;
332
      }
333

    
334
    return false;
335
    }
336
}
(2-2/12)