Files
SoundTool/Assets/Scripts/VoxelOctree/GPU/VoxelRaycastGpuManager.cs

263 lines
9.5 KiB
C#
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using UnityEngine;
using System.Runtime.InteropServices;
using System.Diagnostics;
/// <summary>
/// Owns the GPU buffers used to raycast batches of rays against a flattened voxel octree
/// (<c>raycastShader</c>, kernel "CSMain") and to grid-cluster batch data
/// (<c>clusteringShader</c>, kernels "Accumulate" and "Reduce").
/// Call <see cref="Init"/> before <see cref="Raycast"/> / <see cref="Clustering"/>, and call
/// <see cref="Dispose"/> (from the main thread) when done so every GPU buffer is released.
/// </summary>
public class VoxelRaycastGpuManager : System.IDisposable
{
    ComputeShader raycastShader;
    //---------------- Octree ----------------------------------------------------
    OctreeNode root; // assign your built octree root
    ComputeBuffer nodeBuffer;
    public LinearTree linearTree;
    //--------------- RayCasts ------------------------------------------------------
    int kernel;
    ComputeBuffer datasBuffer = null;
    ComputeBuffer hitCounterBuffer = null;
    ComputeBuffer rayBuffer = null;
    ComputeBuffer hitBuffer = null;
    ComputeBuffer countBuffer = null;
    int raysPerBatch;
    int batchDataClassSize = Marshal.SizeOf(typeof(VoxelRaycastGPU.BatchData));
    int groupsX;
    int maxRaycastPerIteration;
    int maxIterations = 3;
    int threadsY = 8;
    int maxGroupsY = 65535;
    // Number of valid entries currently stored in datasBuffer. Neither datasBuffer.count (the capacity)
    // nor its append counter (reset at the start of every Raycast pass) can be used for this.
    int activeDataCount;
    /// <summary>
    /// When true, <see cref="Raycast"/> reads back the hits of every pass and spawns one sphere
    /// primitive per hit. This is a debug aid and is very expensive for large hit counts.
    /// </summary>
    public bool debugSpawnHitSpheres = true;
    //--------------- Clustering ------------------------------------------------------
    ComputeShader clusteringShader;
    int accumulationKernel;
    int reductionKernel;
    int gridSize;
    int threadCount = 64;
    int groupCountHits;
    int groupCountGrid;
    ComputeBuffer cellClusteringSums;
    ComputeBuffer cellClusteringCounts;
    ComputeBuffer clustered;
    //---------------------------------------------------------------------------------

    /// <summary>Stores the shaders and octree root; no GPU resources are created until <see cref="Init"/>.</summary>
    /// <param name="computeShader">Raycast compute shader (must contain a "CSMain" kernel).</param>
    /// <param name="clusteringShader">Clustering compute shader (must contain "Accumulate" and "Reduce" kernels).</param>
    /// <param name="octreeRoot">Root of the octree that <see cref="Init"/> flattens and uploads.</param>
    public VoxelRaycastGpuManager(ComputeShader computeShader, ComputeShader clusteringShader, OctreeNode octreeRoot)
    {
        raycastShader = computeShader;
        this.clusteringShader = clusteringShader;
        root = octreeRoot;
    }

    /// <summary>
    /// Runs up to <c>maxIterations</c> raycast passes on the GPU. Each pass dispatches "CSMain" with
    /// <c>datasBuffer</c> bound as "batchDatas" and <c>hitBuffer</c> (an append buffer) bound as "hits".
    /// While a pass produces entries, those entries become the input of the next pass (the two
    /// buffers are swapped).
    /// </summary>
    /// <param name="batchData">Initial batches uploaded to the GPU; must fit in the capacity allocated by <see cref="Init"/>.</param>
    /// <param name="datasLenght">Currently unused: the whole <paramref name="batchData"/> array is uploaded.</param>
    /// <returns>The entries appended by the last executed pass (empty if that pass appended nothing or no pass ran).</returns>
    /// <exception cref="System.InvalidOperationException">If <see cref="Init"/> has not been called (or the manager was disposed).</exception>
    public VoxelRaycastGPU.BatchData[] Raycast(in VoxelRaycastGPU.BatchData[] batchData, int datasLenght)
    {
        if (datasBuffer == null || hitBuffer == null || countBuffer == null)
            throw new System.InvalidOperationException("Init must be called before Raycast.");

        int totalCastNumber = 0;
        Stopwatch sw = Stopwatch.StartNew();
        int iteration = 0;
        int currentCount = batchData.Length;
        int lastHitCount = 0;
        int[] countArr = new int[1];

        datasBuffer.SetCounterValue(0);
        datasBuffer.SetData(batchData, 0, 0, currentCount);
        activeDataCount = currentCount;

        while (iteration < maxIterations && currentCount > 0)
        {
            totalCastNumber += currentCount;
            // datasBuffer holds exactly currentCount valid entries for this pass.
            activeDataCount = currentCount;

            hitBuffer.SetCounterValue(0);
            datasBuffer.SetCounterValue(0);
            raycastShader.SetBuffer(kernel, "batchDatas", datasBuffer);
            raycastShader.SetBuffer(kernel, "hits", hitBuffer);

            // Split along Y so no single dispatch exceeds maxGroupsY thread groups.
            for (int y = 0; y < currentCount; y += threadsY * maxGroupsY)
            {
                int remaining = currentCount - y;
                int dispatchGroupsY = Mathf.Min((remaining + threadsY - 1) / threadsY, maxGroupsY);
                raycastShader.SetInt("startIndexY", y);
                raycastShader.Dispatch(kernel, groupsX, dispatchGroupsY, 1);
            }

            ComputeBuffer.CopyCount(hitBuffer, countBuffer, 0);
            countBuffer.GetData(countArr);
            currentCount = countArr[0];

            // The two swapped buffers have different capacities, so the append counter may report
            // more entries than this buffer can hold; never read or forward more than its capacity.
            if (currentCount > hitBuffer.count)
            {
                UnityEngine.Debug.LogWarning($"Raycast pass {iteration} reported {currentCount} hits but the hit buffer only holds {hitBuffer.count}; hits beyond capacity are lost.");
                currentCount = hitBuffer.count;
            }
            lastHitCount = currentCount;

            // Debug visualisation is excluded from the timing.
            sw.Stop();
            if (debugSpawnHitSpheres && currentCount > 0)
            {
                VoxelRaycastGPU.BatchData[] hits = new VoxelRaycastGPU.BatchData[currentCount];
                hitBuffer.GetData(hits, 0, 0, currentCount);
                for (int i = 0; i < hits.Length; i++)
                {
                    GameObject sphere = GameObject.CreatePrimitive(PrimitiveType.Sphere);
                    sphere.transform.position = hits[i].origin;
                    sphere.transform.localScale = Vector3.one * 0.5f;
                }
            }
            sw.Start();

            iteration++;
            if (currentCount > 0 && iteration < maxIterations)
            {
                // This pass's output becomes the next pass's input.
                (datasBuffer, hitBuffer) = (hitBuffer, datasBuffer);
            }
        }
        sw.Stop();

        // No swap happens after the last executed pass, so its output is still in hitBuffer.
        VoxelRaycastGPU.BatchData[] result = new VoxelRaycastGPU.BatchData[lastHitCount];
        if (lastHitCount > 0)
            hitBuffer.GetData(result, 0, 0, lastHitCount);
        UnityEngine.Debug.Log($"Raycast done in {sw.Elapsed.TotalMilliseconds}ms for a total of {totalCastNumber} raycasts");
        return result;
    }

    /// <summary>
    /// Allocates every GPU buffer, flattens and uploads the octree, and binds the static raycast
    /// shader parameters. Calling it again releases the previous buffers first.
    /// </summary>
    /// <param name="nbRaysPerBatch">Rays cast per batch entry; must be strictly positive.</param>
    /// <param name="rays">Ray array uploaded to the "rays" buffer.</param>
    /// <param name="maxIterations">Maximum number of passes performed by <see cref="Raycast"/>.</param>
    /// <exception cref="System.ArgumentOutOfRangeException">If <paramref name="nbRaysPerBatch"/> is not positive.</exception>
    public void Init(int nbRaysPerBatch, in VoxelRaycastGPU.Ray[] rays, int maxIterations)
    {
        if (nbRaysPerBatch <= 0)
            throw new System.ArgumentOutOfRangeException(nameof(nbRaysPerBatch), "Must be strictly positive.");

        // A repeated Init must not leak the buffers created by the previous call.
        ReleaseBuffers();

        maxRaycastPerIteration = 1000000;
        raysPerBatch = nbRaysPerBatch;
        this.maxIterations = maxIterations;
        activeDataCount = 0;
        countBuffer = new ComputeBuffer(1, sizeof(int), ComputeBufferType.Raw);

        // Flatten octree
        linearTree = OctreeGpuHelpers.FlattenOctree(root);
        int nodeStride = Marshal.SizeOf(typeof(LinearNode)); // should be 64

        // checked: a wrapped product would silently allocate a far-too-small hit buffer.
        hitBuffer = new ComputeBuffer(checked(maxRaycastPerIteration * raysPerBatch), batchDataClassSize, ComputeBufferType.Append);
        datasBuffer = new ComputeBuffer(maxRaycastPerIteration, batchDataClassSize, ComputeBufferType.Append);
        rayBuffer = new ComputeBuffer(rays.Length, Marshal.SizeOf(typeof(VoxelRaycastGPU.Ray)), ComputeBufferType.Default);
        rayBuffer.SetData(rays, 0, 0, rays.Length);

        // Create GPU buffer for nodes
        nodeBuffer = new ComputeBuffer(linearTree.nodes.Length, nodeStride, ComputeBufferType.Default);
        nodeBuffer.SetData(linearTree.nodes);

        hitCounterBuffer = new ComputeBuffer(1, sizeof(int), ComputeBufferType.Raw);
        hitCounterBuffer.SetData(new uint[] { 0 });

        kernel = raycastShader.FindKernel("CSMain");
        raycastShader.SetBuffer(kernel, "nodes", nodeBuffer);
        raycastShader.SetBuffer(kernel, "hitCount", hitCounterBuffer);
        raycastShader.SetBuffer(kernel, "rays", rayBuffer);
        raycastShader.SetInt("raysPerBatch", nbRaysPerBatch);
        raycastShader.SetInt("rootIndex", linearTree.rootIndex);
        raycastShader.SetInt("nodeCount", linearTree.nodes.Length);
        raycastShader.SetFloat("rootHalfSize", root.bounds.size.x / 2f);
        raycastShader.SetFloats("rootCenter", new float[3] { root.bounds.center.x, root.bounds.center.y, root.bounds.center.z });
        groupsX = Mathf.CeilToInt((float)raysPerBatch / 8);

        gridSize = maxRaycastPerIteration / raysPerBatch;
        groupCountGrid = Mathf.CeilToInt((float)gridSize / threadCount);
        cellClusteringSums = new ComputeBuffer(gridSize, sizeof(float) * 4);
        cellClusteringCounts = new ComputeBuffer(gridSize, sizeof(uint));
        clustered = new ComputeBuffer(gridSize, Marshal.SizeOf(typeof(VoxelRaycastGPU.BatchData)), ComputeBufferType.Append);
        accumulationKernel = clusteringShader.FindKernel("Accumulate");
        reductionKernel = clusteringShader.FindKernel("Reduce");
    }

    /// <summary>
    /// Bins the valid entries of <c>datasBuffer</c> into a cellsPerAxis^3 grid spanning [-50, 50]^3
    /// ("Accumulate" kernel), then runs the "Reduce" kernel, which appends to an output buffer.
    /// </summary>
    /// <param name="targetGroupCount">Desired number of cells; the grid uses the smallest cellsPerAxis with cellsPerAxis^3 &gt;= this value (minimum 1).</param>
    /// <returns>The number of entries appended by the "Reduce" kernel.</returns>
    /// <exception cref="System.InvalidOperationException">If <see cref="Init"/> has not been called (or the manager was disposed).</exception>
    public int Clustering(int targetGroupCount)
    {
        if (datasBuffer == null || countBuffer == null)
            throw new System.InvalidOperationException("Init must be called before Clustering.");

        int count = activeDataCount;

        // 1. Grid resolution
        int cellsPerAxis = CeilCubeRoot(targetGroupCount);
        Vector3 gridMin = new Vector3(-50, -50, -50);
        Vector3 gridMax = new Vector3(50, 50, 50);
        Vector3 cellSize = (gridMax - gridMin) / cellsPerAxis;
        int totalCells = cellsPerAxis * cellsPerAxis * cellsPerAxis;

        ComputeBuffer cellCount = null;
        ComputeBuffer cellDistanceSum = null;
        ComputeBuffer resultBuffer = null;
        try
        {
            // 2. Per-call buffers. Their initial contents are not guaranteed, so clear the accumulators.
            cellCount = new ComputeBuffer(totalCells, sizeof(uint));
            cellDistanceSum = new ComputeBuffer(totalCells, sizeof(float));
            resultBuffer = new ComputeBuffer(totalCells, sizeof(float) * 4, ComputeBufferType.Append);
            cellCount.SetData(new uint[totalCells]);
            cellDistanceSum.SetData(new float[totalCells]);
            resultBuffer.SetCounterValue(0);

            // 3. Accumulate, split so no dispatch exceeds maxThreadGroups groups.
            clusteringShader.SetBuffer(accumulationKernel, "batchDatas", datasBuffer);
            clusteringShader.SetBuffer(accumulationKernel, "cellCount", cellCount);
            clusteringShader.SetBuffer(accumulationKernel, "cellDistanceSum", cellDistanceSum);
            clusteringShader.SetVector("gridMin", gridMin);
            clusteringShader.SetVector("gridMax", gridMax);
            clusteringShader.SetInts("gridResolution", cellsPerAxis, cellsPerAxis, cellsPerAxis);
            clusteringShader.SetVector("cellSize", cellSize);
            int maxThreadGroups = 65535;
            for (int i = 0; i < count; i += threadCount * maxThreadGroups)
            {
                int remaining = count - i;
                int groups = Mathf.Min((remaining + threadCount - 1) / threadCount, maxThreadGroups);
                clusteringShader.SetInt("startIndex", i);
                clusteringShader.Dispatch(accumulationKernel, groups, 1, 1);
            }

            // 4. Reduce: one thread per cell, rounded up so the tail cells are processed too.
            clusteringShader.SetBuffer(reductionKernel, "cellCount", cellCount);
            clusteringShader.SetBuffer(reductionKernel, "cellDistanceSum", cellDistanceSum);
            clusteringShader.SetBuffer(reductionKernel, "clusteredBatches", resultBuffer);
            int reduceGroups = (totalCells + threadCount - 1) / threadCount;
            clusteringShader.Dispatch(reductionKernel, reduceGroups, 1, 1);

            // 5. Read back how many entries the reduce kernel appended.
            ComputeBuffer.CopyCount(resultBuffer, countBuffer, 0);
            int[] countArray = new int[1];
            countBuffer.GetData(countArray);
            return countArray[0];
        }
        finally
        {
            // Always release the per-call buffers, even if a dispatch or readback throws.
            if (cellCount != null)
                cellCount.Release();
            if (cellDistanceSum != null)
                cellDistanceSum.Release();
            if (resultBuffer != null)
                resultBuffer.Release();
        }
    }

    // Smallest n >= 1 with n^3 >= value. Exact integer check: Mathf.Pow(1000, 1f / 3f) rounds
    // slightly above 10, which a plain CeilToInt would turn into 11.
    static int CeilCubeRoot(int value)
    {
        if (value <= 1)
            return 1;
        int n = Mathf.Max(1, Mathf.RoundToInt(Mathf.Pow(value, 1f / 3f)));
        while ((long)n * n * n < value)
            n++;
        while (n > 1 && (long)(n - 1) * (n - 1) * (n - 1) >= value)
            n--;
        return n;
    }

    /// <summary>Releases every GPU buffer owned by this manager. Call from the main thread.</summary>
    public void Dispose()
    {
        ReleaseBuffers();
        System.GC.SuppressFinalize(this);
    }

    // Releases and clears every owned ComputeBuffer; safe to call repeatedly.
    void ReleaseBuffers()
    {
        ReleaseBuffer(ref hitCounterBuffer);
        ReleaseBuffer(ref rayBuffer);
        ReleaseBuffer(ref hitBuffer);
        ReleaseBuffer(ref datasBuffer);
        ReleaseBuffer(ref countBuffer);
        ReleaseBuffer(ref nodeBuffer);
        ReleaseBuffer(ref cellClusteringSums);
        ReleaseBuffer(ref cellClusteringCounts);
        ReleaseBuffer(ref clustered);
    }

    // Releases one buffer (if any) and nulls the reference so it is never released twice.
    static void ReleaseBuffer(ref ComputeBuffer buffer)
    {
        if (buffer != null)
        {
            buffer.Release();
            buffer = null;
        }
    }

    ~VoxelRaycastGpuManager()
    {
        // NOTE(review): finalizers run on the GC finalizer thread, not Unity's main thread. This
        // fallback keeps the previous behaviour; callers should call Dispose() explicitly instead.
        ReleaseBuffers();
    }
}