mirror of
				https://git.suyu.dev/suyu/suyu.git
				synced 2025-10-26 04:17:12 +08:00 
			
		
		
		
	Merge pull request #3979 from ReinUsesLisp/thread-group
shader/other: Implement thread comparisons (NV_shader_thread_group)
This commit is contained in:
		
						commit
						487dd05170
					
				| @ -2309,6 +2309,18 @@ private: | ||||
|         return {"gl_SubGroupInvocationARB", Type::Uint}; | ||||
|     } | ||||
| 
 | ||||
|     template <const std::string_view& comparison> | ||||
|     Expression ThreadMask(Operation) { | ||||
|         if (device.HasWarpIntrinsics()) { | ||||
|             return {fmt::format("gl_Thread{}MaskNV", comparison), Type::Uint}; | ||||
|         } | ||||
|         if (device.HasShaderBallot()) { | ||||
|             return {fmt::format("uint(gl_SubGroup{}MaskARB)", comparison), Type::Uint}; | ||||
|         } | ||||
|         LOG_ERROR(Render_OpenGL, "Thread mask intrinsics are required by the shader"); | ||||
|         return {"0U", Type::Uint}; | ||||
|     } | ||||
| 
 | ||||
|     Expression ShuffleIndexed(Operation operation) { | ||||
|         std::string value = VisitOperand(operation, 0).AsFloat(); | ||||
| 
 | ||||
| @ -2337,6 +2349,12 @@ private: | ||||
|         static constexpr std::string_view NotEqual = "!="; | ||||
|         static constexpr std::string_view GreaterEqual = ">="; | ||||
| 
 | ||||
|         static constexpr std::string_view Eq = "Eq"; | ||||
|         static constexpr std::string_view Ge = "Ge"; | ||||
|         static constexpr std::string_view Gt = "Gt"; | ||||
|         static constexpr std::string_view Le = "Le"; | ||||
|         static constexpr std::string_view Lt = "Lt"; | ||||
| 
 | ||||
|         static constexpr std::string_view Add = "Add"; | ||||
|         static constexpr std::string_view Min = "Min"; | ||||
|         static constexpr std::string_view Max = "Max"; | ||||
| @ -2554,6 +2572,11 @@ private: | ||||
|         &GLSLDecompiler::VoteEqual, | ||||
| 
 | ||||
|         &GLSLDecompiler::ThreadId, | ||||
|         &GLSLDecompiler::ThreadMask<Func::Eq>, | ||||
|         &GLSLDecompiler::ThreadMask<Func::Ge>, | ||||
|         &GLSLDecompiler::ThreadMask<Func::Gt>, | ||||
|         &GLSLDecompiler::ThreadMask<Func::Le>, | ||||
|         &GLSLDecompiler::ThreadMask<Func::Lt>, | ||||
|         &GLSLDecompiler::ShuffleIndexed, | ||||
| 
 | ||||
|         &GLSLDecompiler::MemoryBarrierGL, | ||||
|  | ||||
| @ -515,6 +515,16 @@ private: | ||||
|     void DeclareCommon() { | ||||
|         thread_id = | ||||
|             DeclareInputBuiltIn(spv::BuiltIn::SubgroupLocalInvocationId, t_in_uint, "thread_id"); | ||||
|         thread_masks[0] = | ||||
|             DeclareInputBuiltIn(spv::BuiltIn::SubgroupEqMask, t_in_uint4, "thread_eq_mask"); | ||||
|         thread_masks[1] = | ||||
|             DeclareInputBuiltIn(spv::BuiltIn::SubgroupGeMask, t_in_uint4, "thread_ge_mask"); | ||||
|         thread_masks[2] = | ||||
|             DeclareInputBuiltIn(spv::BuiltIn::SubgroupGtMask, t_in_uint4, "thread_gt_mask"); | ||||
|         thread_masks[3] = | ||||
|             DeclareInputBuiltIn(spv::BuiltIn::SubgroupLeMask, t_in_uint4, "thread_le_mask"); | ||||
|         thread_masks[4] = | ||||
|             DeclareInputBuiltIn(spv::BuiltIn::SubgroupLtMask, t_in_uint4, "thread_lt_mask"); | ||||
|     } | ||||
| 
 | ||||
|     void DeclareVertex() { | ||||
| @ -2175,6 +2185,13 @@ private: | ||||
|         return {OpLoad(t_uint, thread_id), Type::Uint}; | ||||
|     } | ||||
| 
 | ||||
|     template <std::size_t index> | ||||
|     Expression ThreadMask(Operation) { | ||||
|         // TODO(Rodrigo): Handle devices with different warp sizes
 | ||||
|         const Id mask = thread_masks[index]; | ||||
|         return {OpLoad(t_uint, AccessElement(t_in_uint, mask, 0)), Type::Uint}; | ||||
|     } | ||||
| 
 | ||||
|     Expression ShuffleIndexed(Operation operation) { | ||||
|         const Id value = AsFloat(Visit(operation[0])); | ||||
|         const Id index = AsUint(Visit(operation[1])); | ||||
| @ -2639,6 +2656,11 @@ private: | ||||
|         &SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>, | ||||
| 
 | ||||
|         &SPIRVDecompiler::ThreadId, | ||||
|         &SPIRVDecompiler::ThreadMask<0>, // Eq
 | ||||
|         &SPIRVDecompiler::ThreadMask<1>, // Ge
 | ||||
|         &SPIRVDecompiler::ThreadMask<2>, // Gt
 | ||||
|         &SPIRVDecompiler::ThreadMask<3>, // Le
 | ||||
|         &SPIRVDecompiler::ThreadMask<4>, // Lt
 | ||||
|         &SPIRVDecompiler::ShuffleIndexed, | ||||
| 
 | ||||
|         &SPIRVDecompiler::MemoryBarrierGL, | ||||
| @ -2763,6 +2785,7 @@ private: | ||||
|     Id workgroup_id{}; | ||||
|     Id local_invocation_id{}; | ||||
|     Id thread_id{}; | ||||
|     std::array<Id, 5> thread_masks{}; // eq, ge, gt, le, lt
 | ||||
| 
 | ||||
|     VertexIndices in_indices; | ||||
|     VertexIndices out_indices; | ||||
|  | ||||
| @ -109,6 +109,27 @@ u32 ShaderIR::DecodeOther(NodeBlock& bb, u32 pc) { | ||||
|                 return Operation(OperationCode::WorkGroupIdY); | ||||
|             case SystemVariable::CtaIdZ: | ||||
|                 return Operation(OperationCode::WorkGroupIdZ); | ||||
|             case SystemVariable::EqMask: | ||||
|             case SystemVariable::LtMask: | ||||
|             case SystemVariable::LeMask: | ||||
|             case SystemVariable::GtMask: | ||||
|             case SystemVariable::GeMask: | ||||
|                 uses_warps = true; | ||||
|                 switch (instr.sys20) { | ||||
|                 case SystemVariable::EqMask: | ||||
|                     return Operation(OperationCode::ThreadEqMask); | ||||
|                 case SystemVariable::LtMask: | ||||
|                     return Operation(OperationCode::ThreadLtMask); | ||||
|                 case SystemVariable::LeMask: | ||||
|                     return Operation(OperationCode::ThreadLeMask); | ||||
|                 case SystemVariable::GtMask: | ||||
|                     return Operation(OperationCode::ThreadGtMask); | ||||
|                 case SystemVariable::GeMask: | ||||
|                     return Operation(OperationCode::ThreadGeMask); | ||||
|                 default: | ||||
|                     UNREACHABLE(); | ||||
|                     return Immediate(0u); | ||||
|                 } | ||||
|             default: | ||||
|                 UNIMPLEMENTED_MSG("Unhandled system move: {}", | ||||
|                                   static_cast<u32>(instr.sys20.Value())); | ||||
|  | ||||
| @ -226,6 +226,11 @@ enum class OperationCode { | ||||
|     VoteEqual,    /// (bool) -> bool
 | ||||
| 
 | ||||
|     ThreadId,       /// () -> uint
 | ||||
|     ThreadEqMask,   /// () -> uint
 | ||||
|     ThreadGeMask,   /// () -> uint
 | ||||
|     ThreadGtMask,   /// () -> uint
 | ||||
|     ThreadLeMask,   /// () -> uint
 | ||||
|     ThreadLtMask,   /// () -> uint
 | ||||
|     ShuffleIndexed, /// (uint value, uint index) -> uint
 | ||||
| 
 | ||||
|     MemoryBarrierGL, /// () -> void
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user