using UnityEngine;
using System.Runtime.InteropServices;
using System.Diagnostics;

/// <summary>
/// Drives the raycast compute shader against a flattened octree and hosts an experimental
/// grid-based clustering pass over batch data.
/// </summary>
/// <remarks>
/// Call <see cref="Init"/> before <see cref="Raycast"/> or <see cref="Clustering"/>, and call
/// <see cref="Dispose"/> from the main thread when done so every GPU buffer is released deterministically.
/// </remarks>
public class VoxelRaycastGpuManager : System.IDisposable {
    // Maximum number of raycast passes performed by Raycast.
    const int MaxIterations = 5;

    ComputeShader raycastShader;

    //---------------- Octree ----------------------------------------------------
    OctreeNode root; // assign your built octree root
    ComputeBuffer nodeBuffer;
    public LinearTree linearTree;

    //--------------- RayCasts ------------------------------------------------------
    int kernel;
    // Holds the batches uploaded by Raycast for the first pass; also the input of Clustering.
    ComputeBuffer datasBuffer = null;
    ComputeBuffer hitCounterBuffer = null;
    ComputeBuffer rayBuffer = null;
    // Two append buffers used alternately as "output of this pass" / "input of the next pass",
    // so a pass never reads from the buffer it is appending to. This doubles the hit-buffer memory.
    ComputeBuffer hitBuffer = null;
    ComputeBuffer hitBufferSwap = null;
    // Receives append counters through ComputeBuffer.CopyCount.
    ComputeBuffer countBuffer = null;
    readonly int[] countReadback = new int[1];
    int raysPerBatch;
    int batchDataClassSize = Marshal.SizeOf(typeof(VoxelRaycastGPU.BatchData));
    int groupsX;
    int maxRaycastPerIteration;
    int threadsY = 8;
    int maxGroupsY = 65535;

    //--------------- Clustering ------------------------------------------------------
    ComputeShader clusteringShader;
    int accumulationKernel;
    int reductionKernel;
    int gridSize;
    int threadCount = 64;
    int groupCountHits;
    int groupCountGrid;
    // NOTE(review): allocated in Init but not used by Clustering, which creates its own temporaries.
    ComputeBuffer cellClusteringSums;
    ComputeBuffer cellClusteringCounts;
    ComputeBuffer clustered;
    //---------------------------------------------------------------------------------

    public VoxelRaycastGpuManager(ComputeShader computeShader, ComputeShader clusteringShader, OctreeNode octreeRoot) {
        raycastShader = computeShader;
        this.clusteringShader = clusteringShader;
        root = octreeRoot;
    }

    /// <summary>
    /// Casts <c>raysPerBatch</c> rays from each input batch, then re-casts from the resulting hits,
    /// for at most <see cref="MaxIterations"/> passes or until a pass produces no hits.
    /// </summary>
    /// <param name="batchData">Input batches; the first <paramref name="datasLenght"/> entries are uploaded.</param>
    /// <param name="datasLenght">Number of entries of <paramref name="batchData"/> to process.</param>
    /// <returns>
    /// The hits produced by the last pass that produced any, or an empty array when the
    /// first pass produced none (or <paramref name="datasLenght"/> is 0).
    /// </returns>
    /// <exception cref="System.InvalidOperationException"><see cref="Init"/> has not been called.</exception>
    /// <exception cref="System.ArgumentNullException"><paramref name="batchData"/> is null.</exception>
    /// <exception cref="System.ArgumentOutOfRangeException">
    /// <paramref name="datasLenght"/> is negative, larger than <paramref name="batchData"/>, or larger than the input buffer.
    /// </exception>
    public VoxelRaycastGPU.BatchData[] Raycast(in VoxelRaycastGPU.BatchData[] batchData, int datasLenght) {
        if (datasBuffer == null || hitBuffer == null || hitBufferSwap == null || countBuffer == null) {
            throw new System.InvalidOperationException("Init must be called before Raycast.");
        }
        if (batchData == null) {
            throw new System.ArgumentNullException(nameof(batchData));
        }
        if (datasLenght < 0 || datasLenght > batchData.Length || datasLenght > datasBuffer.count) {
            throw new System.ArgumentOutOfRangeException(nameof(datasLenght), datasLenght,
                "Must be between 0 and both batchData.Length and the input buffer capacity.");
        }
        if (datasLenght == 0) {
            return new VoxelRaycastGPU.BatchData[0];
        }

        datasBuffer.SetData(batchData, 0, 0, datasLenght);

        ComputeBuffer input = datasBuffer;
        ComputeBuffer output = hitBuffer;
        int inputCount = datasLenght;

        ComputeBuffer lastHits = null;
        int lastHitCount = 0;

        for (int iteration = 0; iteration < MaxIterations; iteration++) {
            output.SetCounterValue(0);
            raycastShader.SetBuffer(kernel, "batchDatas", input);
            raycastShader.SetBuffer(kernel, "hits", output);
            DispatchRaycast(inputCount);

            int hitCount = ReadAppendCount(output);
            if (hitCount == 0) {
                break;
            }

            lastHits = output;
            lastHitCount = hitCount;

            // TODO: a clustering pass (see Clustering) was prototyped here to shrink the next pass's
            // input when hitCount * raysPerBatch exceeds maxRaycastPerIteration; it is currently disabled.

            // The hits of this pass feed the next one, which appends into the other hit buffer.
            input = output;
            output = output == hitBuffer ? hitBufferSwap : hitBuffer;
            inputCount = hitCount;
        }

        if (lastHits == null) {
            return new VoxelRaycastGPU.BatchData[0];
        }

        VoxelRaycastGPU.BatchData[] result = new VoxelRaycastGPU.BatchData[lastHitCount];
        lastHits.GetData(result, 0, 0, lastHitCount);
        return result;
    }

    // Dispatches the raycast kernel over batchCount rows, split so no dispatch exceeds maxGroupsY groups in Y.
    void DispatchRaycast(int batchCount) {
        int rowsPerDispatch = threadsY * maxGroupsY;
        for (int startY = 0; startY < batchCount; startY += rowsPerDispatch) {
            int remaining = batchCount - startY;
            int groupsY = Mathf.Min((remaining + threadsY - 1) / threadsY, maxGroupsY);
            raycastShader.SetInt("startIndexY", startY);
            raycastShader.Dispatch(kernel, groupsX, groupsY, 1);
        }
    }

    // Reads the append counter of appendBuffer, clamped to the buffer's capacity.
    int ReadAppendCount(ComputeBuffer appendBuffer) {
        ComputeBuffer.CopyCount(appendBuffer, countBuffer, 0);
        countBuffer.GetData(countReadback);
        int count = countReadback[0];
        // The counter can run past the capacity when more elements are appended than fit (the excess
        // writes are dropped); clamp so the caller never reads beyond the end of the buffer.
        if (count > appendBuffer.count) {
            UnityEngine.Debug.LogWarning($"Append buffer overflow: {count} elements appended, capacity {appendBuffer.count}; extra elements were dropped.");
            count = appendBuffer.count;
        }
        return count;
    }

    /// <summary>
    /// Flattens the octree, allocates all GPU buffers and binds the static raycast shader inputs.
    /// Calling it again releases the buffers of the previous initialisation first.
    /// </summary>
    /// <param name="nbRaysPerBatch">Number of rays cast from each batch; must be positive.</param>
    /// <param name="rays">Ray table uploaded to the shader's <c>rays</c> buffer; must not be empty.</param>
    /// <exception cref="System.ArgumentOutOfRangeException"><paramref name="nbRaysPerBatch"/> is not positive.</exception>
    /// <exception cref="System.ArgumentException"><paramref name="rays"/> is null or empty.</exception>
    /// <exception cref="System.OverflowException">The hit-buffer element count does not fit in an int.</exception>
    public void Init(int nbRaysPerBatch, in VoxelRaycastGPU.Ray[] rays) {
        if (nbRaysPerBatch <= 0) {
            throw new System.ArgumentOutOfRangeException(nameof(nbRaysPerBatch), nbRaysPerBatch, "At least one ray per batch is required.");
        }
        if (rays == null || rays.Length == 0) {
            throw new System.ArgumentException("At least one ray is required.", nameof(rays));
        }

        // Re-initialising must not leak the buffers of a previous Init.
        ReleaseBuffers();

        maxRaycastPerIteration = 1000000;
        raysPerBatch = nbRaysPerBatch;
        countBuffer = new ComputeBuffer(1, sizeof(int), ComputeBufferType.Raw);

        // Flatten octree
        linearTree = OctreeGpuHelpers.FlattenOctree(root);
        int nodeStride = Marshal.SizeOf(typeof(LinearNode)); // should be 64

        int hitCapacity = checked(maxRaycastPerIteration * raysPerBatch);
        hitBuffer = new ComputeBuffer(hitCapacity, batchDataClassSize, ComputeBufferType.Append);
        hitBufferSwap = new ComputeBuffer(hitCapacity, batchDataClassSize, ComputeBufferType.Append);
        datasBuffer = new ComputeBuffer(maxRaycastPerIteration, batchDataClassSize, ComputeBufferType.Default);

        rayBuffer = new ComputeBuffer(rays.Length, Marshal.SizeOf(typeof(VoxelRaycastGPU.Ray)), ComputeBufferType.Default);
        rayBuffer.SetData(rays, 0, 0, rays.Length);

        // Create GPU buffer for nodes
        nodeBuffer = new ComputeBuffer(linearTree.nodes.Length, nodeStride, ComputeBufferType.Default);
        nodeBuffer.SetData(linearTree.nodes);

        hitCounterBuffer = new ComputeBuffer(1, sizeof(int), ComputeBufferType.Raw);
        hitCounterBuffer.SetData(new uint[] { 0 });

        kernel = raycastShader.FindKernel("CSMain");
        raycastShader.SetBuffer(kernel, "nodes", nodeBuffer);
        raycastShader.SetBuffer(kernel, "hitCount", hitCounterBuffer);
        raycastShader.SetBuffer(kernel, "rays", rayBuffer);
        raycastShader.SetInt("raysPerBatch", nbRaysPerBatch);
        raycastShader.SetInt("rootIndex", linearTree.rootIndex);
        raycastShader.SetInt("nodeCount", linearTree.nodes.Length);
        raycastShader.SetFloat("rootHalfSize", root.bounds.size.x / 2f);
        raycastShader.SetFloats("rootCenter", new float[3] { root.bounds.center.x, root.bounds.center.y, root.bounds.center.z });

        // NOTE(review): assumes the kernel's X thread-group size is 8 — confirm against CSMain's numthreads.
        groupsX = (raysPerBatch + 7) / 8;

        gridSize = maxRaycastPerIteration / raysPerBatch;
        groupCountGrid = (gridSize + threadCount - 1) / threadCount;
        cellClusteringSums = new ComputeBuffer(gridSize, sizeof(float) * 4);
        cellClusteringCounts = new ComputeBuffer(gridSize, sizeof(uint));
        clustered = new ComputeBuffer(gridSize, batchDataClassSize, ComputeBufferType.Append);
        accumulationKernel = clusteringShader.FindKernel("Accumulate");
        reductionKernel = clusteringShader.FindKernel("Reduce");
    }

    /// <summary>
    /// Bins the contents of the input batch buffer into a uniform grid over [-50, 50]^3 and runs the
    /// Reduce kernel, which appends one float4 per produced cluster.
    /// </summary>
    /// <param name="targetGroupCount">Desired number of grid cells; the grid uses the smallest cube that covers it.</param>
    /// <returns>The number of elements the Reduce kernel appended.</returns>
    /// <exception cref="System.InvalidOperationException"><see cref="Init"/> has not been called.</exception>
    /// <exception cref="System.ArgumentOutOfRangeException"><paramref name="targetGroupCount"/> is not positive.</exception>
    public int Clustering(int targetGroupCount) {
        if (datasBuffer == null || countBuffer == null) {
            throw new System.InvalidOperationException("Init must be called before Clustering.");
        }
        if (targetGroupCount <= 0) {
            throw new System.ArgumentOutOfRangeException(nameof(targetGroupCount), targetGroupCount, "Must be positive.");
        }

        // NOTE(review): this is the buffer's capacity, not the number of live batches in it.
        int count = datasBuffer.count;

        // 1) Grid resolution
        int cellsPerAxis = CeilCubeRoot(targetGroupCount);
        Vector3 gridMin = new Vector3(-50, -50, -50);
        Vector3 gridMax = new Vector3(50, 50, 50);
        Vector3 cellSize = (gridMax - gridMin) / cellsPerAxis;
        int totalCells = checked(cellsPerAxis * cellsPerAxis * cellsPerAxis);

        const int threadsPerGroup = 64;
        const int maxThreadGroups = 65535;

        ComputeBuffer cellCount = null;
        ComputeBuffer cellDistanceSum = null;
        ComputeBuffer resultBuffer = null;
        try {
            // 2) Per-cell buffers, explicitly zeroed because new buffer contents are not guaranteed
            //    and the Accumulate kernel presumably adds into them.
            cellCount = new ComputeBuffer(totalCells, sizeof(uint));
            cellCount.SetData(new uint[totalCells]);
            cellDistanceSum = new ComputeBuffer(totalCells, sizeof(float));
            cellDistanceSum.SetData(new float[totalCells]);
            resultBuffer = new ComputeBuffer(totalCells, sizeof(float) * 4, ComputeBufferType.Append);
            resultBuffer.SetCounterValue(0);

            // 3) Accumulate, split so no dispatch exceeds maxThreadGroups groups.
            clusteringShader.SetBuffer(accumulationKernel, "batchDatas", datasBuffer);
            clusteringShader.SetBuffer(accumulationKernel, "cellCount", cellCount);
            clusteringShader.SetBuffer(accumulationKernel, "cellDistanceSum", cellDistanceSum);
            clusteringShader.SetVector("gridMin", gridMin);
            clusteringShader.SetVector("gridMax", gridMax);
            clusteringShader.SetInts("gridResolution", cellsPerAxis, cellsPerAxis, cellsPerAxis);
            clusteringShader.SetVector("cellSize", cellSize);

            for (int start = 0; start < count; start += threadsPerGroup * maxThreadGroups) {
                int groups = Mathf.Min((count - start + threadsPerGroup - 1) / threadsPerGroup, maxThreadGroups);
                clusteringShader.SetInt("startIndex", start);
                clusteringShader.Dispatch(accumulationKernel, groups, 1, 1);
            }

            // 4) Reduce: round the group count up so a partial last group (or < 64 cells) is still dispatched.
            clusteringShader.SetBuffer(reductionKernel, "cellCount", cellCount);
            clusteringShader.SetBuffer(reductionKernel, "cellDistanceSum", cellDistanceSum);
            clusteringShader.SetBuffer(reductionKernel, "clusteredBatches", resultBuffer);
            int reduceGroups = (totalCells + threadsPerGroup - 1) / threadsPerGroup;
            clusteringShader.Dispatch(reductionKernel, reduceGroups, 1, 1);

            // 5) Read the number of clusters.
            // NOTE(review): the float4 results are not copied anywhere yet (the former readback into a
            // discarded BatchData[] was removed; BatchData may not match the float4 stride anyway).
            return ReadAppendCount(resultBuffer);
        } finally {
            Release(ref cellCount);
            Release(ref cellDistanceSum);
            Release(ref resultBuffer);
        }
    }

    // Smallest integer n >= 1 with n^3 >= value; avoids float error in Pow (e.g. 27 -> 3, not 4).
    static int CeilCubeRoot(int value) {
        int edge = Mathf.Max(1, Mathf.RoundToInt(Mathf.Pow(value, 1f / 3f)));
        while ((long)edge * edge * edge < value) {
            edge++;
        }
        while (edge > 1 && (long)(edge - 1) * (edge - 1) * (edge - 1) >= value) {
            edge--;
        }
        return edge;
    }

    /// <summary>Releases every GPU buffer owned by this manager. Call from the main thread.</summary>
    public void Dispose() {
        ReleaseBuffers();
        System.GC.SuppressFinalize(this);
    }

    // Safety net when Dispose was never called.
    // NOTE(review): this runs on the GC finalizer thread; Unity generally expects graphics resources to be
    // released on the main thread, so prefer calling Dispose() explicitly.
    ~VoxelRaycastGpuManager() {
        ReleaseBuffers();
    }

    // Releases and clears every owned buffer; safe to call repeatedly.
    void ReleaseBuffers() {
        Release(ref datasBuffer);
        Release(ref hitCounterBuffer);
        Release(ref rayBuffer);
        Release(ref hitBuffer);
        Release(ref hitBufferSwap);
        Release(ref countBuffer);
        Release(ref nodeBuffer);
        Release(ref cellClusteringSums);
        Release(ref cellClusteringCounts);
        Release(ref clustered);
    }

    static void Release(ref ComputeBuffer buffer) {
        if (buffer != null) {
            buffer.Release();
            buffer = null;
        }
    }
}